From a2969aa8b6c3da78fe31a54a589c393480eabe77 Mon Sep 17 00:00:00 2001
From: arai713 <67439843+arai713@users.noreply.github.com>
Date: Wed, 29 Nov 2023 09:36:40 -0800
Subject: [PATCH 01/75] Disable transpose device op for MI300 (#1050)

* added working example for 5D input using 1D kernel
* example with 5D input tensor and 2d kernel - not working: issues with arguments
* added updated version of 3d device op - changed descriptors/dims
* added example file to check kernel
* fixed descriptor and isSupportedArgument stride problem
* added and modified kernel for 3d - updated tids/loop
* adding some more 5d example files
* fixed some issues
* changes made for testing
* working version: fixed error in stride for A, still a bit inefficient
* cleaned up formatting/comments
* updating formatting
* more formatting fixes
* fixing cmake, adding back gpu targets in cmake script
* adding client example
* added instances for client example
* fixed errors in client example
* implemented client example with device_elementwise.hpp and device_elementwise_3d_impl.hpp
* removed extra files
* minor formatting and naming fixes
* adding test files and profiler
* fixing minor error
* minor fix
* removed unnecessary comments, renamed files
* updated instance list for client example, added different layout example
* removing instances
* fixed error in instance generation
* remove comments
* update profiler and client example tensor layouts
* fixed errors in test/profiler
* updated vector dim access to enable vector load
* updated test/profiler files
* updated example with 1d kernel
* updating profiler
* renamed files
* disabled device op for MI300
* skip elementwise_permute_2d on gfx94x
* Update CMakeLists.txt
* fixing CMake - disabling some GPU targets

---------

Co-authored-by: Jing Zhang
Co-authored-by: Jing Zhang
Co-authored-by: zjing14
---
 example/44_elementwise_permute/CMakeLists.txt |  4 +-
 .../impl/device_elementwise_3d_impl.hpp       |  7 ++
 profiler/src/profile_transpose.cpp            | 85 +++++++++++++++++++
 3 files changed, 95 insertions(+), 1 deletion(-)
 create mode 100644 profiler/src/profile_transpose.cpp

diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt
index c68e4cde5b..a963399dc7 100644
--- a/example/44_elementwise_permute/CMakeLists.txt
+++ b/example/44_elementwise_permute/CMakeLists.txt
@@ -5,4 +5,6 @@ add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permu
 add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp)
 add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
 add_example_executable(example_elementwise_permute elementwise_permute.cpp)
-add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)
+if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942"))
+    add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp)
+endif()

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp
index 147efc45ab..67b6f87465 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp
@@ -13,6 +13,7 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/device_prop.hpp"
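+// (ck::get_device_name(), declared in device_prop.hpp, is what the
+// IsSupportedArgument change below uses to reject the gfx940/941/942
+// targets.)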
#include "ck/host_utility/stream_utility.hpp" namespace ck { @@ -292,6 +293,12 @@ struct DeviceElementwise3dImpl : public DeviceElementwise(p_arg); if(pArg == nullptr) diff --git a/profiler/src/profile_transpose.cpp b/profiler/src/profile_transpose.cpp new file mode 100644 index 0000000000..c239a520d1 --- /dev/null +++ b/profiler/src/profile_transpose.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_transpose_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct MatrixLayout +{ + NCDHW, // 0 + NCHWD, // 1 +}; + +enum struct DataType +{ + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16, // 1 +}; + +#define OP_NAME "transpose" +#define OP_DESC "Transpose" + +int profile_transpose(int argc, char* argv[]) +{ + if(argc != 15) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + // printf("arg3: matrix layout (NCDHW -> NDCHW);\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: N, C, D, H, W\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + // const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + std::vector lengths = std::stoi(argv[7]); + + /**const int N = std::stoi(argv[7]); + const int C = std::stoi(argv[8]); + const int D = std::stoi(argv[9]); + const int H = std::stoi(argv[10]); + const int W = std::stoi(argv[11]);**/ + + using F32 = float; + using F16 = ck::half_t; + + auto profile = [&](auto a_type, auto b_type) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + + bool pass = ck::profiler::profile_transpose_impl( + do_verification, init_method, do_log, time_kernel, lengths); + + return pass ? 
+    };
+
+    if(data_type == DataType::F32_F32_F32_F32_F32)
+    {
+        return profile(F32{}, F32{});
+    }
+    else if(data_type == DataType::F16_F16_F16_F16_F16)
+    {
+        return profile(F16{}, F16{});
+    }
+    else
+    {
+        std::cout << "this data_type & layout is not implemented" << std::endl;
+
+        return 1;
+    }
+}
+
+REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_transpose);

From 8ff845f2c4aa7bbdd728b7639a79e0b6932c6dab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?=
Date: Thu, 30 Nov 2023 12:11:43 +0100
Subject: [PATCH 02/75] Introduce wrapper for layout (#1054)

* Introduce wrapper for layout
* Extend functionality
* Fix for getLength
* Comment fixes
* Add comments and remove unneeded getters
---
 example/64_tensor_transforms/CMakeLists.txt     |   2 +
 .../64_tensor_transforms/tensor_transform.cpp   | 150 +++++++
 .../tensor_transform_using_wrapper.cpp          | 119 +++++
 .../tensor_transform_wrapper.hpp                | 425 ++++++++++++++++++
 include/ck/utility/tuple_helper.hpp             |  88 ++++
 5 files changed, 784 insertions(+)
 create mode 100644 example/64_tensor_transforms/CMakeLists.txt
 create mode 100644 example/64_tensor_transforms/tensor_transform.cpp
 create mode 100644 example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
 create mode 100644 example/64_tensor_transforms/tensor_transform_wrapper.hpp

diff --git a/example/64_tensor_transforms/CMakeLists.txt b/example/64_tensor_transforms/CMakeLists.txt
new file mode 100644
index 0000000000..9d14a410e3
--- /dev/null
+++ b/example/64_tensor_transforms/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_example_executable(example_tensor_transform tensor_transform.cpp)
+add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)

diff --git a/example/64_tensor_transforms/tensor_transform.cpp b/example/64_tensor_transforms/tensor_transform.cpp
new file mode 100644
index 0000000000..41ceec1cb5
--- /dev/null
+++ b/example/64_tensor_transforms/tensor_transform.cpp
@@ -0,0 +1,150 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
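+// A naive descriptor maps an index (i0, ..., in) to the offset
+// sum_k(i_k * stride_k); the transforms below only regroup indices, they do
+// not change the underlying offsets. E.g. for dims 4x8 with strides (1, 4),
+// index (1, 1) maps to 1 * 1 + 1 * 4 = 5.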
+
+#include <iostream>
+
+#include "ck/ck.hpp"
+
+#include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
+#include "ck/utility/sequence.hpp"
+
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
+
+static constexpr auto I0 = ck::Number<0>{};
+static constexpr auto I1 = ck::Number<1>{};
+static constexpr auto I2 = ck::Number<2>{};
+
+using DataType = int;
+
+template <typename Desc>
+void Print1d(const Desc& desc)
+{
+    std::cout << "Print1d" << std::endl;
+    for(ck::index_t w = 0; w < desc.GetLength(I0); w++)
+    {
+        std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " ";
+    }
+    std::cout << std::endl;
+}
+
+template <typename Desc>
+void Print2d(const Desc& desc)
+{
+    std::cout << "Print2d" << std::endl;
+    for(ck::index_t h = 0; h < desc.GetLength(I0); h++)
+    {
+        for(ck::index_t w = 0; w < desc.GetLength(I1); w++)
+        {
+            std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " ";
+        }
+        std::cout << std::endl;
+    }
+}
+
+template <typename Desc>
+void Print3dCustom(const Desc& desc)
+{
+    std::cout << "Print3dCustom" << std::endl;
+    for(ck::index_t d = 0; d < desc.GetLength(I0); d++)
+    {
+        for(ck::index_t h = 0; h < desc.GetLength(I1); h++)
+        {
+            for(ck::index_t w = 0; w < desc.GetLength(I2); w++)
+            {
+                std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " ";
+            }
+            std::cout << std::endl;
+        }
+        std::cout << std::endl;
+    }
+}
+
+int main()
+{
+    // Tensor descriptor traverse in row-major (need to reverse dims)
+    std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl;
+    // Basic descriptor 0, 1, 2, ... 30, 31
+    // (dims:4,8 strides:1,4)
+    const auto desc_4x8_s1x4 =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}),
+                                         ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}));
+    std::cout << "dims:4,8 strides:1,4" << std::endl;
+    Print2d(desc_4x8_s1x4);
+
+    using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
+    constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{});
+    std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
+
+    // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
+    // dims:4,(2,4) strides:2,(1,8)
+    const auto desc_4x2x4_s2x1x8 =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8));
+    // Transform to 2d (column-major, need to reverse dims)
+    const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor(
+        desc_4x2x4_s2x1x8,
+        ck::make_tuple(ck::make_pass_through_transform(4),
+                       ck::make_merge_transform(ck::make_tuple(4, 2))),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+
+    std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
+    Print2d(desc_4x2x4_s2x1x8_merged);
+
+    // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
+    // dims:(2,2),(2,4) strides:(1,4),(2,8)
+    const auto desc_2x2x2x4_s1x4x2x8 =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
+    // Transform to 2d
+    const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor(
+        desc_2x2x2x4_s1x4x2x8,
+        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
+                       ck::make_merge_transform(ck::make_tuple(4, 2))),
+        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+    // Transform to 3d
+    const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor(
+        desc_2x2x2x4_s1x4x2x8,
+        ck::make_tuple(ck::make_pass_through_transform(2),
+                       ck::make_pass_through_transform(2),
+                       ck::make_merge_transform(ck::make_tuple(4, 2))),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
+
+    std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
+    Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d);
+    Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d);
+
+    // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
+    // dims:((2,2),2),4 strides:((1,4),2),8
+    // Transform to 2d
+    const auto desc_2x2x2x4_s1x4x2x8_nested =
+        ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8));
+    const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor(
+        desc_2x2x2x4_s1x4x2x8_nested,
+        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)),
+                       ck::make_pass_through_transform(2),
+                       ck::make_pass_through_transform(4)),
+        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
+    const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor(
+        desc_2x2x2x4_s1x4x2x8_nested,
+        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))),
+        ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}),
+        ck::make_tuple(ck::Sequence<0>{}));
+    const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor(
+        desc_2x2x2x4_s1x4x2x8_nested_merged_3d,
+        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)),
+                       ck::make_pass_through_transform(4)),
+        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+
+    std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl;
+    Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d);
+    Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d);
+    Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d);
+
+    return 0;
+}

diff --git a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp b/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
new file mode 100644
index 0000000000..df2449e99d
--- /dev/null
+++ b/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp
@@ -0,0 +1,119 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
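+// Same descriptors as tensor_transform.cpp, but expressed through the Layout
+// wrapper from tensor_transform_wrapper.hpp: make_layout(shape[, strides])
+// wraps a naive descriptor, and layout(idx) derives the merge/pass-through
+// transforms from the nesting of idx before computing the offset.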
+
+#include <iostream>
+
+#include "ck/ck.hpp"
+
+#include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
+#include "ck/utility/sequence.hpp"
+
+#include "tensor_transform_wrapper.hpp"
+
+using DataType = int;
+
+template <typename Layout>
+void Print1d(const Layout& layout)
+{
+    std::cout << "Print1d" << std::endl;
+    for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++)
+    {
+        std::cout << layout(ck::make_tuple(w)) << " ";
+    }
+    std::cout << std::endl;
+}
+
+template <typename Layout>
+void Print2d(const Layout& layout)
+{
+    std::cout << "Print2d" << std::endl;
+    for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++)
+    {
+        for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
+        {
+            std::cout << layout(ck::make_tuple(h, w)) << " ";
+        }
+        std::cout << std::endl;
+    }
+}
+
+// Print in (x,y),z pattern
+template <typename Layout>
+void Print3dCustom(const Layout& layout)
+{
+    std::cout << "Print3dCustom" << std::endl;
+    for(ck::index_t d = 0;
+        d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout));
+        d++)
+    {
+        for(ck::index_t h = 0;
+            h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout));
+            h++)
+        {
+            for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++)
+            {
+                std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " ";
+            }
+            std::cout << std::endl;
+        }
+        std::cout << std::endl;
+    }
+}
+
+int main()
+{
+    // Layout traverse in column-major
+    std::cout << "Note: Layout traverse in column-major" << std::endl;
+    // Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor)
+    // (dims:4,8 strides:1,4)
+    const auto shape_4x8        = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{});
+    const auto layout_4x8_s1x4  = ck::tensor_transform_wrapper::make_layout(shape_4x8);
+    std::cout << "dims:4,8 strides:1,4" << std::endl;
+    Print2d(layout_4x8_s1x4);
+    using Cord1x1Type = ck::Tuple<ck::Number<1>, ck::Number<1>>;
+    constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()<Cord1x1Type>();
+    std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl;
+
+    // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor)
+    // dims:4,(2,4) strides:2,(1,8)
+    const auto shape_4x2x4    = ck::make_tuple(4, ck::make_tuple(2, 4));
+    const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8));
+    const auto layout_4x2x4_s2x1x8 =
+        ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8);
+
+    std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl;
+    Print2d(layout_4x2x4_s2x1x8);
+
+    // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor)
+    // dims:(2,2),(2,4) strides:(1,4),(2,8)
+    const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}),
+                                              ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}));
+    const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}),
+                                                 ck::make_tuple(ck::Number<2>{}, ck::Number<8>{}));
+    static const auto layout_2x2x2x4_s1x4x2x8 =
+        ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8);
+
+    std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl;
+    Print2d(layout_2x2x2x4_s1x4x2x8);
+    Print3dCustom(layout_2x2x2x4_s1x4x2x8);
+
+    // Basic descriptor 0, 1, 8, 9, 16, 17, ...
30, 31 (compile-time descriptor) + // dims:((2,2),2),4 strides:((1,4),2),8 + // Transform to 2d + const auto shape_2x2x2x4_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}), + ck::Number<4>{}); + const auto strides_s1x4x2x8_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), + ck::Number<8>{}); + static const auto layout_2x2x2x4_s1x4x2x8_nested = + ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); + + std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; + Print1d(layout_2x2x2x4_s1x4x2x8_nested); + Print2d(layout_2x2x2x4_s1x4x2x8_nested); + Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested); + + return 0; +} diff --git a/example/64_tensor_transforms/tensor_transform_wrapper.hpp b/example/64_tensor_transforms/tensor_transform_wrapper.hpp new file mode 100644 index 0000000000..71cd6091f8 --- /dev/null +++ b/example/64_tensor_transforms/tensor_transform_wrapper.hpp @@ -0,0 +1,425 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/is_detected.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +namespace ck { +namespace tensor_transform_wrapper { + +/** + * \brief Layout wrapper + * + * \details + * Layout wrapper that performs the tensor descriptor logic. + * + * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t + * (dynamic layout). It is possible to pass nested shapes + * (e.g. ((4, 2), 2)), nested dimensions are merged. + * \tparam Strides Tuple of Number<> (for compile-time layout) or index_t + * (dynamic layout). Stride tuple should be nested if shape tuple is + * nested. 
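+ *
+ * \note Usage sketch (values as in the accompanying examples):
+ *   const auto layout = make_layout(make_tuple(Number<4>{}, Number<8>{}),
+ *                                   make_tuple(Number<1>{}, Number<4>{}));
+ *   const index_t offset = layout(make_tuple(1, 1)); // 1 * 1 + 1 * 4 = 5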
+ */
+template <typename Shape, typename Strides = Tuple<>>
+struct Layout
+{
+    private:
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    template <typename T>
+    using is_tuple = decltype(std::declval<T>().IsTuple());
+
+    // Generate packed (column-major) strides if not passed
+    template <typename... Ts>
+    __host__ __device__ constexpr static auto
+    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& tuple)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                if constexpr(i.value == 0)
+                {
+                    return I1;
+                }
+                else
+                {
+                    return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
+                                                          tuple);
+                }
+            },
+            Number<Tuple<Ts...>::Size()>{});
+    }
+
+    // Generate LowerDims in compile time for MergeTransform using the passed type.
+    // If an element of the tuple is itself a tuple, then merge (generate a sequence
+    // for the merge); if the element is a scalar, then pass through (a one-element sequence).
+    template <typename Idx, typename... Ts>
+    __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
+    {
+        if constexpr(Idx::value == 0)
+        {
+            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
+            {
+                // Return Sequence for the first tuple
+                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
+                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
+                using LowerDimsSequence =
+                    typename arithmetic_sequence_gen<0, merge_nelems, 1>::type;
+                return LowerDimsSequence::Reverse();
+            }
+            else
+            {
+                // Return first element
+                return Sequence<0>{};
+            }
+        }
+        else
+        {
+            // Get previous element using recurrence (in compile time)
+            using PreviousSeqT =
+                decltype(GenerateLowerDim<Number<Idx::value - 1>>(Tuple<Ts...>{}));
+            const auto next_seq_val = PreviousSeqT::At(I0) + 1;
+            if constexpr(is_detected<is_tuple, tuple_element_t<Idx::value, Tuple<Ts...>>>::value)
+            {
+                constexpr index_t merge_nelems = decltype(UnrollNestedTuple(
+                    tuple_element_t<Idx::value, Tuple<Ts...>>{}))::Size();
+                using LowerDimsSequence =
+                    typename arithmetic_sequence_gen<next_seq_val, next_seq_val + merge_nelems, 1>::
+                        type;
+                return LowerDimsSequence::Reverse();
+            }
+            else
+            {
+                return Sequence<next_seq_val>{};
+            }
+        }
+    }
+
+    // Iterate over nested tuples in shape
+    // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
+    // Example idx:    (1, 1), 1, 1
+    // Example shape:  (2, (2, 2)), 2, (2, 2)
+    // Unrolled shape: 2, (2, 2), 2, (2, 2)
+    template <typename... ShapeDims, typename... IdxDims>
+    __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple<ShapeDims...>& shape,
+                                                                const Tuple<IdxDims...>& idx)
+    {
+        if constexpr(!IsNestedTuple(Tuple<IdxDims...>{}))
+        {
+            // Index unrolled to flatten, return shape
+            return shape;
+        }
+        else
+        {
+            // Iterate over shape tuple elements:
+            // 1. If the corresponding idx element is a tuple, return it (it will be unrolled).
+            // 2. If not, pack it in a tuple; the pack is restored during the unroll.
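+            // Trace for the example above: idx element 0 is a tuple, so
+            // (2, (2, 2)) is kept as-is; elements 1 and 2 are scalars, so 2
+            // and (2, 2) are wrapped, giving ((2, (2, 2)), (2), ((2, 2)))
+            // before the one-level unroll below.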
+            auto unrolled_shape_via_idx = generate_tuple(
+                [&](auto i) {
+                    if constexpr(is_detected<is_tuple,
+                                             tuple_element_t<i.value, Tuple<IdxDims...>>>::value)
+                    {
+                        return shape.At(i);
+                    }
+                    else
+                    {
+                        return make_tuple(shape.At(i));
+                    }
+                },
+                Number<Tuple<ShapeDims...>::Size()>{});
+
+            // Unroll and process next step
+            return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx),
+                                     UnrollNestedTuple<0, 1>(idx));
+        }
+    }
+
+    template <typename DescriptorToMerge, typename... Ts>
+    __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<Ts...>& shape,
+                                                          DescriptorToMerge& desc)
+    {
+        // Reverse each element in tuple
+        using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape)));
+        const auto merge_elems      = ReversedUnrolledShape{};
+
+        // Generate reverted indexes (column major traverse)
+        using MergeElemsSequence =
+            typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type;
+        const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
+        const auto upper_dims = make_tuple(Sequence<0>{});
+        // Merge to 1d
+        return transform_tensor_descriptor(
+            desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
+    }
+
+    // Merge nested shape dims
+    // Input desc shape: 2, 2, 2, 2, 2, 2
+    // Example idx:      1, 1, 1, 1
+    // Example shape:    2, (2, 2), 2, (2, 2)
+    // Merged shape:     2, 4, 2, 4
+    template <typename DescriptorToMerge, typename... ShapeDims, typename... IdxDims>
+    __host__ __device__ constexpr static auto
+    MakeMerges(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&, DescriptorToMerge& desc)
+    {
+        const auto transforms = generate_tuple(
+            [&](auto i) {
+                // Compare Idx with shape
+                if constexpr(is_detected<is_tuple,
+                                         tuple_element_t<i.value, Tuple<ShapeDims...>>>::value &&
+                             !is_detected<is_tuple,
+                                          tuple_element_t<i.value, Tuple<IdxDims...>>>::value)
+                {
+                    // If shape element is tuple and idx element is Number, then merge
+                    // Unroll and reverse tuple to traverse column-major
+                    const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
+                    return make_merge_transform(merge_elems);
+                }
+                else
+                {
+                    // If shape element is an integer and idx element is a tuple, the passed idx is wrong
+                    static_assert(
+                        !(!is_detected<is_tuple,
+                                       tuple_element_t<i.value, Tuple<ShapeDims...>>>::value &&
+                          is_detected<is_tuple,
+                                      tuple_element_t<i.value, Tuple<IdxDims...>>>::value),
+                        "Wrong Idx for layout()");
+                    // If shape element has the same type as idx element, then pass through
+                    return make_pass_through_transform(shape.At(i));
+                }
+            },
+            Number<Tuple<ShapeDims...>::Size()>{});
+
+        const auto lower_dims =
+            generate_tuple([&](auto i) { return GenerateLowerDim<Number<i.value>>(shape); },
+                           Number<Tuple<ShapeDims...>::Size()>{});
+        const auto upper_dims = generate_tuple([&](auto i) { return Sequence<i.value>{}; },
+                                               Number<Tuple<ShapeDims...>::Size()>{});
+
+        return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
+    }
+
+    template <typename... ShapeDims, typename... IdxDims>
+    __host__ __device__ constexpr auto TransformDesc(const Tuple<ShapeDims...>& shape,
+                                                     const Tuple<IdxDims...>& idx) const
+    {
+        if constexpr(Tuple<IdxDims...>::Size() == I1)
+        {
+            // 1d idx path
+            return MakeMerge1d(shape, descriptor_);
+        }
+        else
+        {
+            // Merge nested shape dims
+            // Example idx:   (1, 1), 1, 1
+            // Example shape: (2, (2, 2)), 2, (2, 2)
+            // Merged shape:  (2, 4), 2, 4
+            static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
+                          "Idx rank and Shape rank must be the same (except 1d).");
+            // Unroll while IdxDims is nested
+            const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx);
+            // Transform correct form of shape
+            return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_);
+        }
+    }
+
+    template <typename LayoutShape, typename LayoutStrides>
+    __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape,
+                                                        const LayoutStrides& strides)
+    {
+        const auto unrolled_shape = UnrollNestedTuple(shape);
+
+        if constexpr(ck::is_same_v<LayoutStrides, Tuple<>>)
+        {
+            // If shape is packed
+            const auto column_major_packed_strides =
+                GenerateColumnMajorPackedStrides(unrolled_shape);
+            return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides);
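+            // (GenerateColumnMajorPackedStrides yields stride[0] = 1 and
+            // stride[i] = shape[0] * ... * shape[i-1]; e.g. shape (4, 8)
+            // gives strides (1, 4).)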
+        }
+        else
+        {
+            const auto unrolled_strides = UnrollNestedTuple(strides);
+            static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
+                          "Size of strides and shape are not consistent.");
+            return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
+        }
+    }
+
+    public:
+    using NaiveDescriptorType = remove_cvref_t<decltype(MakeNaiveDescriptor(Shape{}, Strides{}))>;
+
+    /**
+     * \brief Layout constructor.
+     *
+     * \param shape Shape for layout.
+     * \param strides Strides for layout (optional if tensor is packed).
+     * \return Layout object.
+     */
+    __host__ __device__ Layout() = delete;
+    __host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{}
+    {
+        // Construct if runtime mode
+        if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
+        {
+            // Keep only the shape; strides are not needed for the transforms
+            shape_      = shape;
+            descriptor_ = MakeNaiveDescriptor(shape, strides);
+        }
+    }
+
+    __host__ __device__ Layout(const Shape& shape) : descriptor_{}
+    {
+        if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime())
+        {
+            shape_      = shape;
+            descriptor_ = MakeNaiveDescriptor(shape, Strides{});
+        }
+    }
+
+    /**
+     * \brief Returns real offset to element, computed in compile time.
+     *
+     * \tparam Idxs Tuple of indexes.
+     * \return Calculated offset.
+     */
+    template <typename Idxs>
+    __host__ __device__ constexpr index_t operator()() const
+    {
+        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}));
+        using UnrolledIdx     = decltype(UnrollNestedTuple(Idxs{}));
+        return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
+    }
+
+    /**
+     * \brief Returns real offset to element, computed in runtime.
+     *
+     * \param Idx Tuple of indexes.
+     * \return Calculated offset.
+     */
+    template <typename... Ts>
+    __host__ __device__ index_t operator()(const Tuple<Ts...>& Idx) const
+    {
+        // Static so the transformed_desc is constructed only once
+        static const auto transformed_desc = TransformDesc(shape_, Idx);
+        return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
+    }
+
+    /**
+     * \brief Length getter (product if the dimension is a tuple).
+     *
+     * \tparam IDim Dimension index.
+     * \return Calculated size.
+     */
+    template <index_t IDim>
+    __host__ __device__ constexpr index_t GetLength() const
+    {
+        const auto elem = shape_.At(Number<IDim>{});
+        if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
+        {
+            const auto unrolled_element = UnrollNestedTuple(elem);
+            return TupleReduce<I0.value, decltype(unrolled_element)::Size()>(
+                [](auto x, auto y) { return x * y; }, unrolled_element);
+        }
+        else
+        {
+            return elem;
+        }
+    }
+
+    /**
+     * \brief Layout size getter (product of shape).
+     *
+     * \return Calculated size.
+     */
+    __host__ __device__ constexpr index_t GetLength() const
+    {
+        const auto unrolled_shape = UnrollNestedTuple(shape_);
+        return TupleReduce<I0.value, decltype(unrolled_shape)::Size()>(
+            [](auto x, auto y) { return x * y; }, unrolled_shape);
+    }
+
+    /**
+     * \brief Dimension getter.
+     *
+     * \tparam IDim Dimension index.
+     * \return The dimension (a sub-shape tuple or a scalar length).
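+     *
+     * \note For shape ((2, 2), (2, 4)), Get<0>() returns the sub-tuple (2, 2).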
+     */
+    template <index_t IDim>
+    __host__ __device__ constexpr auto Get() const
+    {
+        const auto elem = shape_.At(Number<IDim>{});
+        return elem;
+    }
+
+    private:
+    NaiveDescriptorType descriptor_;
+    Shape shape_;
+};
+
+// Layout helpers
+// Length getter (product if tuple)
+template <index_t IDim, typename Shape, typename Strides>
+__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+{
+    return layout.template GetLength<IDim>();
+}
+
+// Get shape size (product of dims if tuple)
+template <typename... ShapeDims>
+__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
+{
+    using UnrolledShape = decltype(UnrollNestedTuple(shape));
+    return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
+                                                 UnrolledShape{});
+}
+
+// Get dim size (could be returned from get function)
+template <typename T>
+__host__ __device__ T constexpr size(const T& dim)
+{
+    return dim;
+}
+
+// Get layout size (product of shapes)
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+{
+    return layout.GetLength();
+}
+
+// Get shape element size
+template <index_t IDim, typename... ShapeDims>
+__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
+{
+    return size(shape.At(Number<IDim>{}));
+}
+
+// Dim getter (tuple if tuple)
+template <index_t IDim, typename Shape, typename Strides>
+__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
+{
+    return layout.template Get<IDim>();
+}
+
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
+                                                                 const Strides& strides)
+{
+    return Layout<Shape, Strides>(shape, strides);
+}
+
+template <typename Shape>
+__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
+{
+    return Layout<Shape>(shape);
+}
+
+} // namespace tensor_transform_wrapper
+} // namespace ck

diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp
index e39ae1c23d..d7b492fe66 100644
--- a/include/ck/utility/tuple_helper.hpp
+++ b/include/ck/utility/tuple_helper.hpp
@@ -5,6 +5,7 @@
 
 #include "functional4.hpp"
 #include "tuple.hpp"
+#include "is_detected.hpp"
 
 namespace ck {
 
@@ -33,6 +34,28 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<Xs&...>& tx,
                                                              const Tuple<Ys&...>& ty)
 {
 }
 
+template <typename... Xs, typename... Ys>
+__host__ __device__ constexpr auto concat_tuple(const Tuple<Xs...>& tx, const Tuple<Ys...>& ty)
+{
+    return unpack2(
+        [&](auto... zs) { return Tuple<Xs..., Ys...>{std::forward<decltype(zs)>(zs)...}; },
+        tx,
+        ty);
+}
+
+// Support any number of tuples to concat (also 1)
+template <typename... Xs>
+__host__ __device__ constexpr auto concat_tuple(const Tuple<Xs...>& tx)
+{
+    return tx;
+}
+
+template <typename... Xs, typename... Tuples>
+__host__ __device__ constexpr auto concat_tuple(const Tuple<Xs...>& tx, const Tuples&... tuples)
+{
+    return concat_tuple(tx, concat_tuple(tuples...));
+}
+
 namespace detail {
 
 template
@@ -78,4 +101,69 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
         f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
 }
 
+// By default unroll to the flatten
+template <index_t Depth = 0, index_t MaxDepth = -1>
+__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
+{
+    return element;
+}
+
+template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
+__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
+{
+    return make_tuple(element);
+}
+
+template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
+__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
+{
+    if constexpr(Depth == MaxDepth)
+    {
+        return tuple;
+    }
+    else
+    {
+        return unpack(
+            [&](auto&&... ts) {
+                return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
+            },
+            tuple);
+    }
+}
+
+template <typename... Ts>
+__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
+{
+    return generate_tuple(
+        [&](auto i) {
+            using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
+            return tuple.At(Idx{});
+        },
+        Number<Tuple<Ts...>::Size()>{});
+}
+
+// Reduce tuple values in specific range using Function
+template <index_t Idx, index_t End, typename F, typename... Ts>
+__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
+{
+    static_assert(Idx < End, "Wrong parameters for TupleReduce");
+    if constexpr(Idx + 1 == End)
+    {
+        return tuple.At(Number<Idx>{});
+    }
+    else
+    {
+        return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
+    }
+}
+
+template <typename T>
+using is_tuple = decltype(std::declval<T>().IsTuple());
+
+template <typename... Ts>
+__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
+{
+    return (is_detected<is_tuple, Ts>::value || ...);
+}
+
 } // namespace ck

From 49df1dc595734d20ecdf9dfe11933e527fea84f1 Mon Sep 17 00:00:00 2001
From: zjing14
Date: Thu, 30 Nov 2023 15:09:27 -0600
Subject: [PATCH 03/75] Fixed GroupedGemmFixedNK with hipGraph (#1065)

* fixed examples; add async_mem_set

* add stream to all deviceOp using SetWorkspace

---------

Co-authored-by: Jing Zhang
---
 example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp | 4 ++--
 example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp  | 4 ++--
 include/ck/tensor_operation/gpu/device/device_base.hpp     | 4 +++-
 .../gpu/device/impl/device_batchnorm_backward_impl.hpp     | 4 +++-
 .../gpu/device/impl/device_batchnorm_forward_impl.hpp      | 4 +++-
 .../device/impl/device_batchnorm_forward_impl_obsolete.hpp | 4 +++-
 .../impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 4 +++-
 .../gpu/device/impl/device_gemm_xdl_streamk.hpp            | 4 +++-
 .../gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp   | 7 +++++--
 .../device/impl/device_normalization_fwd_splitk_impl.hpp   | 4 +++-
 10 files changed, 30 insertions(+), 13 deletions(-)

diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
index 95b8526094..2c1feafce3 100644
--- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
@@ -299,8 +299,8 @@ int main(int argc, char* argv[])
     for(int i = 0; i < problem_size.group_count; i++)
     {
         problem_size.Ms.push_back(256 + 256 * i);
-        problem_size.Ns.push_back(128 + 128 * i);
-        problem_size.Ks.push_back(128 + 64 * i);
+        problem_size.Ns.push_back(256);
+        problem_size.Ks.push_back(128);
 
         problem_size.stride_As.push_back(problem_size.Ks[i]);
         problem_size.stride_Bs.push_back(problem_size.Ks[i]);

diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
index 84abe1d1db..9fd63cba77 100644
--- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
+++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp
@@ -300,8 +300,8 @@ int main(int argc, char* argv[])
     for(int i = 0; i < problem_size.group_count; i++)
    {
         problem_size.Ms.push_back(256 + 256 * i);
-        problem_size.Ns.push_back(128 + 128 * i);
-        problem_size.Ks.push_back(128 + 64 * i);
+        problem_size.Ns.push_back(256);
+        problem_size.Ks.push_back(128);
 
         problem_size.stride_As.push_back(problem_size.Ks[i]);
         problem_size.stride_Bs.push_back(problem_size.Ks[i]);

diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp
index 1981690111..908ada016d 100644
--- a/include/ck/tensor_operation/gpu/device/device_base.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_base.hpp
@@ -59,7 +59,9 @@ struct BaseOperator
 
     virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
 
-    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
+    virtual void SetWorkSpacePointer(BaseArgument* p_arg,
+                                     void* p_workspace,
+                                     const StreamConfig& = StreamConfig{}) const
     {
         assert(p_arg);
         p_arg->p_workspace_ = p_workspace;

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
index f46237e005..3b62cf10a3 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp
@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd
 
-    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+    void SetWorkSpacePointer(BaseArgument* pArg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
     {
         auto* pArg_ = dynamic_cast<Argument*>(pArg);

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
index ad8e795603..e7e4668d92 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd
 
-    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+    void SetWorkSpacePointer(BaseArgument* pArg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
     {
         auto* pArg_ = dynamic_cast<Argument*>(pArg);

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
index b826793c27..c3e0837722 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp
@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd
 
-    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+    void SetWorkSpacePointer(BaseArgument* pArg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
     {
         auto* pArg_ = dynamic_cast<Argument*>(pArg);

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
index b0efa9d4e4..f7319226a9 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
         return (workspace_size);
     };
 
-    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+    void SetWorkSpacePointer(BaseArgument* pArg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
     {
         Argument* pArg_ = dynamic_cast<Argument*>(pArg);

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
index 8de42ba9ef..c8799e5154 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK
 
-    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+    void SetWorkSpacePointer(BaseArgument* pArg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
     {
         auto* pArg_ = dynamic_cast<Argument*>(pArg);

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
index 56132f7a0f..0a0cb59063 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK
     }
 
-    void
 SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
+    void SetWorkSpacePointer(BaseArgument* p_arg,
+                             void* p_workspace,
+                             const StreamConfig& stream_config = StreamConfig{}) const override
     {
         auto p_arg_ = dynamic_cast<Argument*>(p_arg);
 
         p_arg_->p_workspace_ = p_workspace;
 
-        hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg)));
+        hip_check_error(
+            hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
     }
 
     static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
index 58db34c9f2..6a117920f4 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd
 
-    void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+    void SetWorkSpacePointer(BaseArgument* pArg,
+                             void* p_workspace,
+                             const StreamConfig& = StreamConfig{}) const override
     {
         auto* pArg_ = dynamic_cast<Argument*>(pArg);

From c7d5c7727b9f2fcb7141a57761ad91134bbb6317 Mon Sep 17 00:00:00 2001
From: Jun Liu
Date: Thu, 30 Nov 2023 15:24:59 -0800
Subject: [PATCH 04/75] [CI] Update Jenkinsfile (#1073)

---
 Jenkinsfile | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 505232de47..91499e7eb8 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -1,5 +1,5 @@
 def rocmnode(name) {
-    return 'rocmtest && miopen && ' + name
+    return '(rocmtest || miopen) && ' + name
 }
 
 def show_node_info() {

From bc4bf9bd03a74aba1860b80fbeb85fb1f47b8b19 Mon Sep 17 00:00:00 2001
From: Bartlomiej Wroblewski
Date: Sun, 3 Dec 2023 23:08:47 +0100
Subject: [PATCH 05/75] Add support for double buffering in direct load GEMM
 kernel (#1052)

This PR introduces support for double buffering in LDS in GEMM kernels that
use direct load instructions.

Direct loads now use inline asm instead of intrinsics: using the intrinsics
results in the compiler adding additional waitcnt instructions, which breaks
the possible load/compute overlap when double buffering. Using inline asm,
in turn, requires sched_barrier to make sure the compiler cannot incorrectly
reschedule instructions, since it does not know the data dependencies between
the global->LDS and LDS->register transfers.
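For illustration, a CPU analogue of the ping-pong scheme this patch enables
(all names here are illustrative sketches, not the kernel's actual symbols):
while one buffer is being "computed" on, the next chunk is loaded into the
other buffer.

    #include <array>
    #include <vector>

    // Double buffering: copy chunk i+1 into one buffer while summing the
    // other, mirroring how the kernel overlaps global->LDS direct loads
    // with compute on the previously loaded LDS buffer.
    float pipelined_sum(const std::vector<std::array<float, 64>>& chunks)
    {
        if(chunks.empty())
            return 0.f;

        std::array<float, 64> buf[2];
        buf[0]    = chunks[0]; // prologue: fill the first buffer
        float acc = 0.f;

        for(std::size_t i = 0; i < chunks.size(); ++i)
        {
            if(i + 1 < chunks.size())
                buf[(i + 1) % 2] = chunks[i + 1]; // "load" the next chunk
            for(float v : buf[i % 2])             // "compute" on the current one
                acc += v;
        }
        return acc;
    }

On a CPU the two phases run back to back; on the GPU they genuinely overlap,
which is why a compiler-inserted waitcnt between the load of one buffer and
the compute on the other serializes the loop.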
---
 include/ck/ck.hpp                             |   3 +
 ...vice_gemm_xdl_cshuffle_lds_direct_load.hpp |   4 +-
 ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp |  41 +++--
 .../gridwise_gemm_pipeline_v4_direct_load.hpp | 147 +++++++++++++++++-
 include/ck/utility/amd_buffer_addressing.hpp  |  10 ++
 ...ect_load_f16_f16_f16_mk_nk_mn_instance.cpp |  16 +-
 ...ect_load_f32_f32_f32_km_kn_mn_instance.cpp |   3 +-
 ...ect_load_f32_f32_f32_km_nk_mn_instance.cpp |   3 +-
 ...ect_load_f32_f32_f32_mk_kn_mn_instance.cpp |   3 +-
 ...ect_load_f32_f32_f32_mk_nk_mn_instance.cpp |   5 +-
 10 files changed, 211 insertions(+), 24 deletions(-)

diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp
index 4a2b5c0ad7..a94057be4a 100644
--- a/include/ck/ck.hpp
+++ b/include/ck/ck.hpp
@@ -134,6 +134,9 @@
 // inner product using V_DOT with DPP8 modifiers
 #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
 
+// LDS direct loads using inline assembly
+#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
+
 // set stochastic rounding as default for f8 conversions
 #define CK_USE_SR_F8_CONVERSION 1

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp
index f8264cefd3..ac2e826725 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp
@@ -380,7 +380,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
+    template <typename DataType, index_t NumBuffers>
+    __device__ static auto AllocateBlockBuffers(void* p_shared,
+                                                int32_t num_elems,
+                                                int32_t offset_elems,
+                                                int32_t max_lds_align)
+    {
+        const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align);
+        return generate_tuple(
+            [&](auto i) {
+                const int32_t local_offset = i * single_buffer_offset;
+                return make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<DataType*>(p_shared) + local_offset + offset_elems, num_elems);
+            },
+            Number<NumBuffers>{});
+    }
+
     template
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<BDataType*>(p_shared) + a_block_space_size_aligned,
-            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
+        auto a_block_buffers = AllocateBlockBuffers<ADataType, NumGemmKPrefetchStage>(
+            p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align);
+        const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage;
+        auto b_block_buffers =
+            AllocateBlockBuffers<BDataType, NumGemmKPrefetchStage>(
+                p_shared,
+                b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
+                b_buffers_offset,
+                max_lds_align);
 
         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
@@ -645,13 +664,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
             a_block_desc_ak0_m_ak1,
             a_blockwise_copy,
             a_grid_buf,
-            a_block_buf,
+            a_block_buffers,
             a_block_slice_copy_step,
             b_grid_desc_bk0_n_bk1,
             b_block_desc_bk0_n_bk1,
             b_blockwise_copy,
             b_grid_buf,
-            b_block_buf,
+            b_block_buffers,
            b_block_slice_copy_step,
            blockwise_gemm,
            c_thread_buf,

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
index 1c59f37a9e..08d986d0da 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
@@ -7,6 +7,20 @@
 #include "ck/utility/loop_scheduler.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 
"ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +namespace lds_direct_load { + +__device__ void sched_barrier() +{ +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + // When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of + // `sched_barrier` is necessary to make sure no instructions that use the loaded memory + // are scheduled by the compiler before the `waitcnt` instruction. + __builtin_amdgcn_sched_barrier(0); +#endif +} + +} // namespace lds_direct_load + namespace ck { template @@ -17,7 +31,6 @@ template <> struct GridwiseGemmPipeline_v4<1> { static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } @@ -31,13 +44,13 @@ struct GridwiseGemmPipeline_v4<1> typename ABlockDesc, typename ABlockTransfer, typename AGridBuffer, - typename ABlockBuffer, + typename ABlockBuffers, typename ABlockTransferStep, typename BGridDesc, typename BBlockDesc, typename BBlockTransfer, typename BGridBuffer, - typename BBlockBuffer, + typename BBlockBuffers, typename BBlockTransferStep, typename BlockwiseGemm, typename CThreadBuffer> @@ -45,18 +58,22 @@ struct GridwiseGemmPipeline_v4<1> const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, const AGridBuffer& a_grid_buf, - ABlockBuffer& a_block_buf, + ABlockBuffers& a_block_bufs, const ABlockTransferStep& a_block_copy_step, const BGridDesc& b_grid_desc, const BBlockDesc& b_block_desc, BBlockTransfer& b_blockwise_copy, const BGridBuffer& b_grid_buf, - BBlockBuffer& b_block_buf, + BBlockBuffers& b_block_bufs, const BBlockTransferStep& b_block_copy_step, const BlockwiseGemm& blockwise_gemm, CThreadBuffer& c_thread_buf, index_t num_loop) { + static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1); + auto& a_block_buf = a_block_bufs.At(I0); + auto& b_block_buf = b_block_bufs.At(I0); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); @@ -74,10 +91,12 @@ struct GridwiseGemmPipeline_v4<1> do { block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); @@ -92,10 +111,128 @@ struct GridwiseGemmPipeline_v4<1> // tail { block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } } }; +// 2-stages prefetch +template <> +struct GridwiseGemmPipeline_v4<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const 
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2);
+        auto& a_block_buf1 = a_block_bufs.At(I0);
+        auto& a_block_buf2 = a_block_bufs.At(I1);
+        auto& b_block_buf1 = b_block_bufs.At(I0);
+        auto& b_block_buf2 = b_block_bufs.At(I1);
+
+        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
+        b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                block_sync_lds_direct_load();
+                lds_direct_load::sched_barrier();
+
+                a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
+                b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
+
+                block_sync_lds_direct_load();
+                lds_direct_load::sched_barrier();
+
+                a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1);
+                b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1);
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
+
+                i += 2;
+            } while(i < (num_loop - 2));
+        }
+
+        // tail
+        {
+            block_sync_lds_direct_load();
+            lds_direct_load::sched_barrier();
+
+            a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2);
+            b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2);
+
+            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+            blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf);
+
+            block_sync_lds_direct_load();
+            lds_direct_load::sched_barrier();
+
+            blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf);
+        }
+    }
+};
+
 } // namespace ck

diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp
index ef3874ba3a..2ea5419d09 100644
--- a/include/ck/utility/amd_buffer_addressing.hpp
+++ b/include/ck/utility/amd_buffer_addressing.hpp
@@ -972,6 +972,15 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
     const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
     const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
 
+#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
+    T* lds_ptr = lds_base_ptr + lds_offset;
+    auto const lds_ptr_sgpr =
+        __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
+    asm volatile("s_mov_b32 m0, %0; \n\t"
+                 "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
+                 "v"(global_offset_bytes),
+                 "s"(src_resource));
+#else
     // LDS pointer must be attributed with the LDS address space.
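+    // (Unlike the inline-asm path above, which must program the M0 register
+    // with the LDS base address before issuing `buffer_load_dword ... lds`,
+    // the intrinsic takes the LDS destination pointer directly, so it needs
+    // the explicit address-space attribute below.)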
__attribute__((address_space(3))) uint32_t* lds_ptr = reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( @@ -979,6 +988,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, llvm_amdgcn_raw_buffer_load_lds( src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); +#endif } } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp index 9c96e12c32..bb40237bf9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp @@ -35,7 +35,21 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, 
F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp index fcfd766b04..94f75d0e0f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instance.cpp @@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, 
Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp index 68c0488803..0f4ebc350b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instance.cpp @@ -32,7 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp index ef09478d1c..d2bc9351b6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instance.cpp @@ -31,7 +31,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| 
XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp index aec5421627..2c208c01f3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instance.cpp @@ -24,8 +24,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off @@ -34,7 +33,7 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances = // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, 
PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; From afe4622014ab7f1d2e74743ed0e39ae63b13410c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 4 Dec 2023 19:04:52 -0800 Subject: [PATCH 06/75] Add daily run with mainline compiler. (#1075) * add daily build with mainline compiler * fix the compiler paths for ci * remove the -flto flag * build with clang by default --- CMakeLists.txt | 5 ++--- Dockerfile | 4 ++-- Jenkinsfile | 19 ++++++++++--------- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 04674124cc..e780c15657 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -373,10 +373,9 @@ include_directories(BEFORE SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") if(BUILD_DEV) - add_compile_options(-Werror -Weverything) + add_compile_options(-Werror) + add_compile_options(-Weverything) endif() -#add flags to reduce the size of binaries -add_compile_options(-Oz -flto=thin) message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) diff --git a/Dockerfile b/Dockerfile index 7134e206c1..87b4eb8e2b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -111,7 +111,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ +RUN if [ "$compiler_version" != "" ] && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -119,7 +119,7 @@ RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; else echo "using the release compiler"; \ fi -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ +RUN if [ "$compiler_version" != "" ] && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ diff --git a/Jenkinsfile b/Jenkinsfile index 91499e7eb8..8e67f9cc39 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -84,7 +84,7 @@ def build_compiler(){ compiler = '/opt/rocm/bin/hipcc' } else{ - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ compiler = "/llvm-project/build/bin/clang++" } else{ @@ -293,7 +293,7 @@ def buildHipClangJob(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if 
(params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -348,7 +348,7 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -479,7 +479,7 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -657,7 +657,8 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION= 0 21 * * * % ROCMVERSION=5.7;COMPILER_VERSION=;COMPILER_COMMIT= - 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false + 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" pipeline { agent none @@ -679,15 +680,15 @@ pipeline { string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-stg-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-stg-open, amd-mainline-open, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', - description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.') + description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.') string( name: 'BUILD_COMPILER', - defaultValue: 'hipcc', - description: 'Specify whether to build CK with hipcc (default) or with clang.') + defaultValue: 'clang', + description: 'Specify whether to build CK with hipcc or with clang (default).') booleanParam( name: "RUN_FULL_QA", defaultValue: false, From ff24b537cb5412b4720a8923bbc090de6d020a3b Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Mon, 4 Dec 2023 23:45:16 -0800 Subject: [PATCH 07/75] [SWDEV-435347] disable instances failing with mainline compiler (#1077) --- ...rouped_convolution_forward_scaleadd_ab.hpp | 43 ++++++----- ..._ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 73 ++++++++++--------- 2 files changed, 60 insertions(+), 56 deletions(-) diff --git
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index 1bea403afa..348bcaef8a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -23,19 +23,20 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; #ifdef CK_ENABLE_BF16 // grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - std::vector, - NDHWGK, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - ScaleAdd, - ScaleAdd, - PassThrough>>>& instances); +// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347 +// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +// std::vector, +// NDHWGK, +// ck::Tuple, +// ck::Tuple, +// ck::Tuple<>, +// BF16, +// ScaleAdd, +// ScaleAdd, +// PassThrough>>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -151,13 +152,15 @@ struct DeviceOperationInstanceFactory> && - is_same_v> && - is_same_v && is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - op_ptrs); - } + // TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347 + // if constexpr(is_same_v> && + // is_same_v> && + // is_same_v && is_same_v) + // { + // add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + // op_ptrs); + // } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v> && diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index c7801f02ce..d5b9da86c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -9,42 +9,43 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - std::vector, - NDHWGK, - ck::Tuple, - ck::Tuple, - ck::Tuple<>, - BF16, - ScaleAdd, - ScaleAdd, - PassThrough>>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvFwd1x1S1P0>{}); -} +// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347 +// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +// std::vector, +// NDHWGK, +// ck::Tuple, +// ck::Tuple, +// ck::Tuple<>, +// BF16, +// ScaleAdd, +// ScaleAdd, +// PassThrough>>>& instances) +// { +// 
add_device_operation_instances( +// instances, +// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, +// NDHWGC, +// GKZYXC, +// NDHWGK, +// ConvFwdDefault>{}); +// add_device_operation_instances( +// instances, +// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, +// NDHWGC, +// GKZYXC, +// NDHWGK, +// ConvFwd1x1P0>{}); +// add_device_operation_instances( +// instances, +// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, +// NDHWGC, +// GKZYXC, +// NDHWGK, +// ConvFwd1x1S1P0>{}); +// } } // namespace instance } // namespace device From f60cd9d7a6911f30b412a6405f0041221bc64ea9 Mon Sep 17 00:00:00 2001 From: Sam Wu Date: Tue, 5 Dec 2023 11:05:55 -0700 Subject: [PATCH 08/75] Standardize documentation for ReadtheDocs (#1057) Relates to https://github.com/RadeonOpenCompute/rocm-docs-core/issues/330 --- .github/dependabot.yml | 6 ++++++ .gitignore | 1 - .readthedocs.yaml | 10 +++++----- docs/conf.py | 27 +++++++++++++++++++-------- docs/doxygen/Doxyfile | 2 +- docs/sphinx/_toc.yml.in | 6 +++--- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 6 ++---- 8 files changed, 37 insertions(+), 23 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 276690bd4f..0e0a252eb6 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -10,3 +10,9 @@ updates: open-pull-requests-limit: 10 schedule: interval: "daily" + labels: + - "documentation" + - "dependencies" + - "ci:docs-only" + reviewers: + - "samjwu" diff --git a/.gitignore b/.gitignore index 7af066c82d..340f11cbd2 100644 --- a/.gitignore +++ b/.gitignore @@ -54,5 +54,4 @@ _images/ _static/ _templates/ _toc.yml -docBin/ _doxygen/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 5f50df2525..9e6678abe5 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,11 +3,6 @@ version: 2 -build: - os: ubuntu-22.04 - tools: - python: "3.8" - sphinx: configuration: docs/conf.py @@ -16,3 +11,8 @@ formats: [htmlzip, pdf, epub] python: install: - requirements: docs/sphinx/requirements.txt + +build: + os: ubuntu-22.04 + tools: + python: "3.8" diff --git a/docs/conf.py b/docs/conf.py index 0de590da1a..e441ff1ced 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,23 +4,34 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -import subprocess +import re from rocm_docs import ROCmDocs +html_theme_options = {"flavor": "list"} -name = "Composable Kernel" -get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt' -version = subprocess.getoutput(get_version) -if len(version) > 0: - name = f"{name} {version}" +with open('../CMakeLists.txt', encoding='utf-8') as f: + match = re.search(r'.*set\(version ([0-9.]+)[^0-9.]+', f.read()) + if not match: + raise ValueError("VERSION not found!") + version_number = match[1] +left_nav_title = f"Composable Kernel {version_number} Documentation" + +# for PDF output on Read the Docs +project = "Composable Kernel Documentation" +author = "Advanced Micro Devices, Inc." +copyright = "Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved." 
+version = version_number +release = version_number external_toc_path = "./sphinx/_toc.yml" -docs_core = ROCmDocs(f"{name} Documentation") -docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml") +docs_core = ROCmDocs(left_nav_title) +docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") docs_core.setup() +external_projects_current_project = "composable_kernel" + mathjax3_config = { 'tex': { 'macros': { diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index 1084f94c81..2594422095 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -58,7 +58,7 @@ PROJECT_LOGO = # entered, it will be relative to the location where doxygen was started. If # left blank the current directory will be used. -OUTPUT_DIRECTORY = docBin +OUTPUT_DIRECTORY = . # If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- # directories (in 2 levels) under the output directory of each output format and diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 83dd1e7b1a..c37ba29cec 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -5,6 +5,6 @@ defaults: maxdepth: 6 root: index subtrees: - - caption: About - entries: - - file: license +- caption: About + entries: + - file: license diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index c4ce8be79a..f5ee431e7d 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core>=0.20.0 +rocm-docs-core==0.29.0 sphinxcontrib-bibtex==2.6.1 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 5852315958..0442ae9a2b 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -96,9 +96,7 @@ pygments==2.14.0 # pydata-sphinx-theme # sphinx pyjwt[crypto]==2.6.0 - # via - # pygithub - # pyjwt + # via pygithub pynacl==1.5.0 # via pygithub pytz==2023.3.post1 @@ -113,7 +111,7 @@ requests==2.28.2 # via # pygithub # sphinx -rocm-docs-core==0.27.0 +rocm-docs-core==0.29.0 # via -r requirements.in six==1.16.0 # via From 836b7e557d028cc2d7c6b341352253fd81003e54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 6 Dec 2023 11:58:59 +0100 Subject: [PATCH 09/75] Introduce wrapper library (#1071) * Introduce wrapper library * Update cmake files * Revert "Update cmake files" This reverts commit c27f88b56590c11a88e26d5d0df7aca51a08133d. 
* Fix comments --- CHANGELOG.md | 1 + .../25_tensor_transforms/CMakeLists.txt | 4 + .../tensor_transform.cpp | 0 .../tensor_transform_using_wrapper.cpp | 31 +- docs/doxygen/Doxyfile | 4 +- docs/index.rst | 2 + docs/wrapper.rst | 54 ++ example/64_tensor_transforms/CMakeLists.txt | 2 - include/ck/utility/tuple_helper.hpp | 12 + .../ck/wrapper/layout.hpp | 181 ++----- include/ck/wrapper/layout_utils.hpp | 321 ++++++++++++ test/CMakeLists.txt | 1 + test/wrapper/CMakeLists.txt | 2 + test/wrapper/test_layout.cpp | 481 ++++++++++++++++++ 14 files changed, 945 insertions(+), 151 deletions(-) create mode 100644 client_example/25_tensor_transforms/CMakeLists.txt rename {example/64_tensor_transforms => client_example/25_tensor_transforms}/tensor_transform.cpp (100%) rename {example/64_tensor_transforms => client_example/25_tensor_transforms}/tensor_transform_using_wrapper.cpp (74%) create mode 100644 docs/wrapper.rst delete mode 100644 example/64_tensor_transforms/CMakeLists.txt rename example/64_tensor_transforms/tensor_transform_wrapper.hpp => include/ck/wrapper/layout.hpp (68%) create mode 100644 include/ck/wrapper/layout_utils.hpp create mode 100644 test/wrapper/CMakeLists.txt create mode 100644 test/wrapper/test_layout.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e46a4ab4b..3da22fc790 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ None - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) - Support for Batched Gemm DL (#732) +- Introduce wrapper sublibrary (limited functionality) (#1071) ### Changes - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) diff --git a/client_example/25_tensor_transforms/CMakeLists.txt b/client_example/25_tensor_transforms/CMakeLists.txt new file mode 100644 index 0000000000..d1543fb0ef --- /dev/null +++ b/client_example/25_tensor_transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(client_tensor_transform tensor_transform.cpp) +target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations) +add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) +target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) diff --git a/example/64_tensor_transforms/tensor_transform.cpp b/client_example/25_tensor_transforms/tensor_transform.cpp similarity index 100% rename from example/64_tensor_transforms/tensor_transform.cpp rename to client_example/25_tensor_transforms/tensor_transform.cpp diff --git a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp b/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp similarity index 74% rename from example/64_tensor_transforms/tensor_transform_using_wrapper.cpp rename to client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp index df2449e99d..de9fcde0b4 100644 --- a/example/64_tensor_transforms/tensor_transform_using_wrapper.cpp +++ b/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp @@ -9,7 +9,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/sequence.hpp" -#include "tensor_transform_wrapper.hpp" +#include "ck/wrapper/layout.hpp" using DataType = int; @@ -17,7 +17,7 @@ template void Print1d(const Layout& layout) { std::cout << "Print1d" << std::endl; - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size(layout); w++) + for(ck::index_t w = 0; w < 
ck::wrapper::size(layout); w++) { std::cout << layout(ck::make_tuple(w)) << " "; } @@ -28,9 +28,9 @@ template void Print2d(const Layout& layout) { std::cout << "Print2d" << std::endl; - for(ck::index_t h = 0; h < ck::tensor_transform_wrapper::size<0>(layout); h++) + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) { - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) { std::cout << layout(ck::make_tuple(h, w)) << " "; } @@ -43,15 +43,11 @@ template void Print3dCustom(const Layout& layout) { std::cout << "Print3dCustom" << std::endl; - for(ck::index_t d = 0; - d < ck::tensor_transform_wrapper::size<0>(ck::tensor_transform_wrapper::get<0>(layout)); - d++) + for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++) { - for(ck::index_t h = 0; - h < ck::tensor_transform_wrapper::size<1>(ck::tensor_transform_wrapper::get<0>(layout)); - h++) + for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++) { - for(ck::index_t w = 0; w < ck::tensor_transform_wrapper::size<1>(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) { std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " "; } @@ -68,7 +64,7 @@ int main() // Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor) // (dims:4,8 strides:1,4) const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}); - const auto layout_4x8_s1x4 = ck::tensor_transform_wrapper::make_layout(shape_4x8); + const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8); std::cout << "dims:4,8 strides:1,4" << std::endl; Print2d(layout_4x8_s1x4); using Cord1x1Type = ck::Tuple, ck::Number<1>>; @@ -77,10 +73,9 @@ int main() // Basic descriptor 0, 1, 8, 9, 16, 17, ... 
30, 31 (runtime descriptor) // dims:4,(2,4) strides:2,(1,8) - const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); - const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); - const auto layout_4x2x4_s2x1x8 = - ck::tensor_transform_wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; Print2d(layout_4x2x4_s2x1x8); @@ -92,7 +87,7 @@ int main() const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::make_tuple(ck::Number<2>{}, ck::Number<8>{})); static const auto layout_2x2x2x4_s1x4x2x8 = - ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); + ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; Print2d(layout_2x2x2x4_s1x4x2x8); @@ -108,7 +103,7 @@ int main() ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), ck::Number<8>{}); static const auto layout_2x2x2x4_s1x4x2x8_nested = - ck::tensor_transform_wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); + ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; Print1d(layout_2x2x2x4_s1x4x2x8_nested); diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index 2594422095..fac9e138e1 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -778,7 +778,9 @@ WARN_LOGFILE = INPUT = ../../include/ck/tensor_operation/gpu/grid \ ../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/thread \ - ../../library/include/ck/library/utility + ../../library/include/ck/library/utility \ + ../../include/ck/wrapper + # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/docs/index.rst b/docs/index.rst index 51c0c862ae..8c4aaa2b3d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,6 +34,7 @@ Current CK library are structured into 4 layers: * "Templated Tile Operators" layer * "Templated Kernel and Invoker" layer * "Instantiated Kernel and Invoker" layer +* "Wrapper for tensor transform operations" * "Client API" layer .. image:: data/ck_layer.png @@ -50,6 +51,7 @@ The following is a list of CK documents in the suggested reading order: tutorial_hello_world dockerhub + wrapper Supported_Primitives_Guide API_Reference_Guide Contributors_Guide diff --git a/docs/wrapper.rst b/docs/wrapper.rst new file mode 100644 index 0000000000..64fb6a4031 --- /dev/null +++ b/docs/wrapper.rst @@ -0,0 +1,54 @@ +=============== +Wrapper +=============== + +------------------------------------- +Description +------------------------------------- + +.. note:: + + The wrapper is under development and its functionality is limited. + + +CK provides a lightweight wrapper for more complex operations implemented in +the library. It allows indexing of nested layouts using a simple interface +(avoiding complex descriptor transformations). + +Example: + +.. 
code-block:: c + + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(h, w)) << " "; + } + std::cout << std::endl; + } + +Output:: + + dims:4,(2,4) strides:2,(1,8) + 0 1 8 9 16 17 24 25 + 2 3 10 11 18 19 26 27 + 4 5 12 13 20 21 28 29 + 6 7 14 15 22 23 30 31 + +------------------------------------- +Layout +------------------------------------- + +.. doxygenstruct:: ck::wrapper::Layout + +------------------------------------- +Layout helpers +------------------------------------- + +.. doxygenfile:: layout_utils.hpp diff --git a/example/64_tensor_transforms/CMakeLists.txt b/example/64_tensor_transforms/CMakeLists.txt deleted file mode 100644 index 9d14a410e3..0000000000 --- a/example/64_tensor_transforms/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_example_executable(example_tensor_transform tensor_transform.cpp) -add_example_executable(example_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index d7b492fe66..75f2693f20 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -166,4 +166,16 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple&) return (is_detected::value || ...); } +template +__host__ __device__ constexpr auto TupleDepth(const T&) +{ + return depth; +} + +template +__host__ __device__ constexpr auto TupleDepth(const Tuple&) +{ + return math::max(TupleDepth(Ts{})...); +} + } // namespace ck diff --git a/example/64_tensor_transforms/tensor_transform_wrapper.hpp b/include/ck/wrapper/layout.hpp similarity index 68% rename from example/64_tensor_transforms/tensor_transform_wrapper.hpp rename to include/ck/wrapper/layout.hpp index 71cd6091f8..b337d88a1a 100644 --- a/example/64_tensor_transforms/tensor_transform_wrapper.hpp +++ b/include/ck/wrapper/layout.hpp @@ -3,27 +3,13 @@ #pragma once -#include "ck/ck.hpp" - -#include "ck/utility/number.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/tuple_helper.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/utility/sequence_helper.hpp" -#include "ck/utility/is_detected.hpp" - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/wrapper/layout_utils.hpp" namespace ck { -namespace tensor_transform_wrapper { +namespace wrapper { /** - * \brief Layout wrapper - * - * \details - * Layout wrapper that performs the tensor descriptor logic. + * \brief Layout wrapper that performs the tensor descriptor logic. * * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * (dynamic layout). It is possible to pass nested shapes @@ -32,21 +18,19 @@ namespace tensor_transform_wrapper { * (dynamic layout). Stride tuple should be nested if shape tuple is * nested. 
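+ *
+ * An illustrative sketch (hypothetical values, mirroring the client example
+ * earlier in this patch): for shape make_tuple(4, make_tuple(2, 4)) with
+ * strides make_tuple(2, make_tuple(1, 8)), layout(make_tuple(1, 1)) returns
+ * the linear offset 3.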
*/ -template > +template struct Layout { private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - template - using is_tuple = decltype(std::declval().IsTuple()); - // Generate packed (column-major) strides if not passed template __host__ __device__ constexpr static auto - GenerateColumnMajorPackedStrides(const Tuple& tuple) + GenerateColumnMajorPackedStrides(const Tuple& shape) { + const auto unrolled_shape = UnrollNestedTuple(shape); return generate_tuple( [&](auto i) { if constexpr(i.value == 0) @@ -56,10 +40,10 @@ struct Layout else { return TupleReduce([](auto x, auto y) { return x * y; }, - tuple); + unrolled_shape); } }, - Number::Size()>{}); + Number{}); } // Generate LowerDims in Compile-time for MergeTrasform using passed Type @@ -112,8 +96,8 @@ struct Layout // Example shape: (2, (2, 2)), 2, (2, 2) // Unrolled shape: 2, (2, 2), 2, (2, 2) template - __host__ __device__ constexpr static auto UnrollShapeViaIdx(const Tuple& shape, - const Tuple& idx) + __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple& shape, + const Tuple& idx) { if constexpr(!IsNestedTuple(Tuple{})) { @@ -125,7 +109,7 @@ struct Layout // Iterate over shape tuple elements: // 1. If corresponding idx element is tuple then return (will be unrolled) // 2. If no, pack in tuple. It will be restored during unroll. - auto unrolled_shape_via_idx = generate_tuple( + auto aligned_shape = generate_tuple( [&](auto i) { if constexpr(is_detected>>::value) @@ -140,8 +124,8 @@ struct Layout Number::Size()>{}); // Unroll and process next step - return UnrollShapeViaIdx(UnrollNestedTuple<0, 1>(unrolled_shape_via_idx), - UnrollNestedTuple<0, 1>(idx)); + return AlignShapeToIdx(UnrollNestedTuple<0, 1>(aligned_shape), + UnrollNestedTuple<0, 1>(idx)); } } @@ -150,27 +134,24 @@ struct Layout DescriptorToMerge& desc) { // Reverse each element in tuple - using ReversedUnrolledShape = decltype(TupleReverse(UnrollNestedTuple(shape))); - const auto merge_elems = ReversedUnrolledShape{}; - + const auto merge_elems = TupleReverse(UnrollNestedTuple(shape)); // Generate reverted indexes (column major traverse) - using MergeElemsSequence = - typename arithmetic_sequence_gen<0, ReversedUnrolledShape::Size(), 1>::type; - const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); - const auto upper_dims = make_tuple(Sequence<0>{}); + using MergeElemsSequence = typename arithmetic_sequence_gen<0, merge_elems.Size(), 1>::type; + const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); + const auto upper_dims = make_tuple(Sequence<0>{}); // Merge to 1d return transform_tensor_descriptor( desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); } - // Merge nested shape dims + // Merge nested shape dims. Merge nested shape dims when idx is also nested. 
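+    // Each merged output dim covers a group of nested input dims, so its
+    // length is the product of the lengths in that group (e.g. the (2, 2)
+    // pairs below merge to dims of length 4).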
// Input desc shape: 2, 2, 2, 2, 2, 2 // Example idx: 1, 1, 1, 1 // Example shape: 2, (2, 2), 2, (2, 2) // Merged shape: 2, 4, 2, 4 template - __host__ __device__ constexpr static auto - MakeMerges(const Tuple& shape, const Tuple&, DescriptorToMerge& desc) + __host__ __device__ constexpr static auto CreateMergedDescriptor( + const Tuple& shape, const Tuple&, DescriptorToMerge& desc) { const auto transforms = generate_tuple( [&](auto i) { @@ -224,9 +205,9 @@ struct Layout static_assert(Tuple::Size() == Tuple::Size(), "Idx rank and Shape rank must be the same (except 1d)."); // Unroll while IdxDims is nested - const auto unrolled_shape_via_idx = UnrollShapeViaIdx(shape, idx); + const auto aligned_shape = AlignShapeToIdx(shape, idx); // Transform correct form of shape - return MakeMerges(unrolled_shape_via_idx, UnrollNestedTuple(idx), descriptor_); + return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_); } } @@ -234,26 +215,21 @@ struct Layout __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape, const LayoutStrides& strides) { - const auto unrolled_shape = UnrollNestedTuple(shape); - - if constexpr(ck::is_same_v>) - { - // If shape is packed - const auto column_major_packed_strides = - GenerateColumnMajorPackedStrides(unrolled_shape); - return make_naive_tensor_descriptor(unrolled_shape, column_major_packed_strides); - } - else - { - const auto unrolled_strides = UnrollNestedTuple(strides); - static_assert(unrolled_shape.Size() == unrolled_strides.Size(), - "Size of strides and shape are not consistent."); - return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); - } + const auto unrolled_shape = UnrollNestedTuple(shape); + const auto unrolled_strides = UnrollNestedTuple(strides); + static_assert(unrolled_shape.Size() == unrolled_strides.Size(), + "Size of strides and shape are not consistent."); + return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); } public: - using NaiveDescriptorType = remove_cvref_t; + // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`. + using DeducedStrides = + std::conditional_t>, + remove_cvref_t, + Strides>; + using NaiveDescriptorType = + remove_cvref_t; /** * \brief Layout constructor. @@ -268,9 +244,9 @@ struct Layout // Construct if runtime mode if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) { - // Keep only shape, strides are not need for transforms shape_ = shape; - descriptor_ = MakeNaiveDescriptor(shape, strides); + strides_ = strides; + descriptor_ = MakeNaiveDescriptor(shape_, strides_); } } @@ -279,7 +255,8 @@ struct Layout if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) { shape_ = shape; - descriptor_ = MakeNaiveDescriptor(shape, Strides{}); + strides_ = GenerateColumnMajorPackedStrides(shape_); + descriptor_ = MakeNaiveDescriptor(shape_, strides_); } } @@ -338,7 +315,7 @@ struct Layout * * \return Calculated size. */ - __host__ __device__ constexpr index_t GetLength() const + __host__ __device__ constexpr index_t GetLengths() const { const auto unrolled_shape = UnrollNestedTuple(shape_); return TupleReduce([](auto x, auto y) { return x * y; }, @@ -346,80 +323,24 @@ struct Layout } /** - * \brief Dimension getter. + * \brief Shape getter. * - * \tparam IDim Dimension idx. - * \return Calculated size. + * \return Shape. 
*/
+    __host__ __device__ constexpr Shape GetShape() const { return shape_; }
+
+    /**
+     * \brief Strides getter.
+     *
+     * \return Strides.
+     */
+    __host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; }

     private:
     NaiveDescriptorType descriptor_;
     Shape shape_;
+    DeducedStrides strides_;
 };

-// Layout helpers
-// Length getter (product if tuple)
-template <index_t IDim, typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
-{
-    return layout.template GetLength<IDim>();
-}
-
-// Get shape size (product of dims if tuple)
-template <typename... Ts>
-__host__ __device__ constexpr index_t size(const Tuple<Ts...>& shape)
-{
-    using UnrolledShape = decltype(UnrollNestedTuple(shape));
-    return TupleReduce<0, UnrolledShape::Size()>([](auto x, auto y) { return x * y; },
-                                                 UnrolledShape{});
-}
-
-// Get dim size (could be returned from get function)
-template <typename T>
-__host__ __device__ T constexpr size(const T& dim)
-{
-    return dim;
-}
-
-// Get layout size (product of shapes)
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
-{
-    return layout.GetLength();
-}
-
-// Get shape element size
-template <index_t IDim, typename... Ts>
-__host__ __device__ constexpr index_t size(const Tuple<Ts...>& shape)
-{
-    return size(shape.At(Number<IDim>{}));
-}
-
-// Dim getter (tuple if tuple)
-template <index_t IDim, typename Shape, typename Strides>
-__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
-{
-    return layout.template Get<IDim>();
-}
-
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
-                                                                 const Strides& strides)
-{
-    return Layout<Shape, Strides>(shape, strides);
-}
-
-template <typename Shape>
-__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
-{
-    return Layout<Shape>(shape);
-}
-
-} // namespace tensor_transform_wrapper
+} // namespace wrapper
 } // namespace ck
diff --git a/include/ck/wrapper/layout_utils.hpp b/include/ck/wrapper/layout_utils.hpp
new file mode 100644
index 0000000000..fac8f33854
--- /dev/null
+++ b/include/ck/wrapper/layout_utils.hpp
@@ -0,0 +1,321 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck/ck.hpp"
+
+#include "ck/utility/number.hpp"
+#include "ck/utility/tuple.hpp"
+#include "ck/utility/tuple_helper.hpp"
+#include "ck/utility/sequence.hpp"
+#include "ck/utility/sequence_helper.hpp"
+#include "ck/utility/is_detected.hpp"
+
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
+
+namespace ck {
+namespace wrapper {
+
+// Disable from doxygen docs generation
+/// @cond
+// forward declaration
+template <typename Shape, typename Strides = Tuple<>>
+struct Layout;
+
+template <typename T>
+using is_tuple = decltype(std::declval<T&>().IsTuple());
+/// @endcond
+
+// make_*
+/**
+ * \brief Make layout function.
+ *
+ * \tparam Shape Shape for layout.
+ * \tparam Strides Strides for layout.
+ * \return Constructed layout.
+ */
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& shape,
+                                                                 const Strides& strides)
+{
+    return Layout<Shape, Strides>(shape, strides);
+}
+
+/**
+ * \brief Make layout function with packed strides
+ * (column-major).
+ *
+ * \tparam Shape Shape for layout.
+ * \return Constructed layout.
+ */
+template <typename Shape>
+__host__ __device__ constexpr Layout<Shape> make_layout(const Shape& shape)
+{
+    return Layout<Shape>(shape);
+}
+
+// Layout helpers
+// get
+/**
+ * \brief Get element from tuple (Shape/Strides/Idxs).
+ *
+ * \tparam idx Index to lookup.
+ * \param tuple Tuple to lookup.
+ * \return Requested element.
+ */
+template <index_t idx, typename... Ts>
+__host__ __device__ constexpr auto get(const Tuple<Ts...>& tuple)
+{
+    return tuple.At(Number<idx>{});
+}
+
+/**
+ * \brief Get sub layout.
+ *
+ * \tparam idx Index to lookup.
+ * \param layout Layout to create sub layout.
+ * \return Requested sub layout.
+ */
+template <index_t idx, typename Shape, typename Strides>
+__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
+{
+    const auto new_shape = get<idx>(layout.GetShape());
+    static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
+                  "Shape of sub layout must be tuple");
+    if constexpr(is_same_v<Strides, Tuple<>>)
+    {
+        // If stride not passed, create without strides
+        return make_layout(new_shape);
+    }
+    else
+    {
+        const auto new_strides = get<idx>(layout.GetStrides());
+        static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
+                      "Strides of sub layout must be tuple");
+        return make_layout(new_shape, new_strides);
+    }
+}
+
+/**
+ * \brief Hierarchical get.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested element.
+ */
+template <index_t Idx, index_t... Idxs, typename T>
+__host__ __device__ constexpr auto get(const T& elem)
+{
+    return get<Idxs...>(get<Idx>(elem));
+}
+
+// size
+/**
+ * \brief Length get (product if tuple).
+ *
+ * \tparam idx Index to lookup.
+ * \param layout Layout to get Shape.
+ * \return Requested length.
+ */
+template <index_t idx, typename Shape, typename Strides>
+__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+{
+    return layout.template GetLength<idx>();
+}
+
+/**
+ * \brief Shape size (product of dims).
+ *
+ * \param shape Shape to lookup.
+ * \return Requested size.
+ */
+template <typename... Ts>
+__host__ __device__ constexpr index_t size(const Tuple<Ts...>& shape)
+{
+    const auto unrolled_shape = UnrollNestedTuple(shape);
+    return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
+                                                 unrolled_shape);
+}
+
+// Get dim size (could be returned from get function)
+/**
+ * \private
+ */
+template <typename T>
+__host__ __device__ T constexpr size(const T& dim)
+{
+    return dim;
+}
+
+/**
+ * \brief Layout size (product of dims).
+ *
+ * \param layout Layout to calculate shape size.
+ * \return Requested size.
+ */
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+{
+    return layout.GetLengths();
+}
+
+/**
+ * \brief Length get from tuple (product if tuple).
+ *
+ * \tparam idx Index to lookup.
+ * \param tuple Tuple to lookup.
+ * \return Requested length.
+ */
+template <index_t idx, typename... Ts>
+__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
+{
+    return size(tuple.At(Number<idx>{}));
+}
+
+/**
+ * \brief Hierarchical size.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested size.
+ */
+template <index_t... Idxs, typename T>
+__host__ __device__ constexpr auto size(const T& elem)
+{
+    return size(get<Idxs...>(elem));
+}
+
+// rank
+/**
+ * \brief Get layout rank (num elements in shape).
+ *
+ * \param layout Layout to calculate rank.
+ * \return Requested rank.
+ */
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
+{
+    return Shape::Size();
+}
+
+/**
+ * \brief Get tuple rank (num elements in tuple).
+ * Return 1 if scalar passed.
+ *
+ * \param tuple Tuple to calculate rank.
+ * \return Requested rank.
+ */
+template <typename... Ts>
+__host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Ts...>& tuple)
+{
+    return Tuple<Ts...>::Size();
+}
+
+/**
+ * \private
+ */
+template <index_t idx>
+__host__ __device__ constexpr index_t rank(const Number<idx>&)
+{
+    return 1;
+}
+
+/**
+ * \private
+ */
+__host__ __device__ constexpr index_t rank(const index_t&) { return 1; }
+
+/**
+ * \brief Hierarchical rank.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested rank.
+ */
+template <index_t... Idxs, typename T>
+__host__ __device__ constexpr auto rank(const T& elem)
+{
+    return rank(get<Idxs...>(elem));
+}
+
+// depth
+/**
+ * \brief Get depth of the layout shape (return 0 if scalar).
+ *
+ * \param layout Layout to calculate depth.
+ * \return Requested depth.
+ */
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
+{
+    return TupleDepth(layout.GetShape());
+}
+
+/**
+ * \brief Get depth of the tuple. (return 0 if scalar)
+ *
+ * \param tuple Tuple to calculate depth.
+ * \return Requested depth.
+ */
+template <typename... Ts>
+__host__ __device__ constexpr auto depth(const Tuple<Ts...>& tuple)
+{
+    return TupleDepth(tuple);
+}
+
+/**
+ * \private
+ */
+template <index_t idx>
+__host__ __device__ constexpr index_t depth(const Number<idx>&)
+{
+    return 0;
+}
+
+/**
+ * \private
+ */
+__host__ __device__ constexpr index_t depth(const index_t&) { return 0; }
+
+/**
+ * \brief Hierarchical depth.
+ *
+ * \tparam Idxs Indexes to lookup.
+ * \param elem Element to lookup.
+ * \return Requested depth.
+ */
+template <index_t... Idxs, typename T>
+__host__ __device__ constexpr auto depth(const T& elem)
+{
+    return depth(get<Idxs...>(elem));
+}
+
+/**
+ * \brief Get Layout strides.
+ *
+ * \param layout Layout to get strides.
+ * \return Requested strides.
+ */
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr auto stride(const Layout<Shape, Strides>& layout)
+{
+    return layout.GetStrides();
+}
+
+/**
+ * \brief Get Layout shape.
+ *
+ * \param layout Layout to get shape.
+ * \return Requested shape.
+ */
+template <typename Shape, typename Strides>
+__host__ __device__ constexpr auto shape(const Layout<Shape, Strides>& layout)
+{
+    return layout.GetShape();
+}
+
+} // namespace wrapper
+} // namespace ck
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 4aaa5fcfa5..b325a3a7f8 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -149,6 +149,7 @@ add_subdirectory(batched_gemm_multi_d)
 add_subdirectory(grouped_convnd_bwd_data)
 add_subdirectory(conv_tensor_rearrange)
 add_subdirectory(transpose)
+add_subdirectory(wrapper)
 if(GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt
new file mode 100644
index 0000000000..e25ef176dd
--- /dev/null
+++ b/test/wrapper/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_gtest_executable(test_layout test_layout.cpp)
+target_link_libraries(test_layout PRIVATE utility)
diff --git a/test/wrapper/test_layout.cpp b/test/wrapper/test_layout.cpp
new file mode 100644
index 0000000000..7d09696fbb
--- /dev/null
+++ b/test/wrapper/test_layout.cpp
@@ -0,0 +1,481 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
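+//
+// The fixture below cross-checks ck::wrapper::Layout against hand-built tensor
+// descriptors: both must produce identical linear offsets for every index,
+// whether the layout is indexed 1d, n-d, or hierarchically.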
+ +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/wrapper/layout.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +class TestWrapperLayout : public ::testing::Test +{ + protected: + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + + template + void Run(Desc& desc, + Desc1d& desc_1d, + LayoutRuntime& layout_runtime, + LayoutCompiletime& layout_compiletime, + const std::vector& idxs) + { + // 1d check + EXPECT_EQ(desc_1d.GetLength(I0), ck::wrapper::size(layout_runtime)); + // Check layout compiletime and runtime result consistency + EXPECT_EQ(ck::wrapper::size(layout_runtime), ck::wrapper::size(layout_compiletime)); + + for(ck::index_t i = 0; i < desc_1d.GetLength(I0); i++) + { + const ck::index_t layout_runtime_offset_1d = layout_runtime(ck::make_tuple(i)); + const ck::index_t layout_compiletime_offset_1d = layout_compiletime(ck::make_tuple(i)); + const ck::index_t desc_offset_1d = desc_1d.CalculateOffset(ck::make_tuple(i)); + EXPECT_EQ(layout_runtime_offset_1d, desc_offset_1d); + EXPECT_EQ(layout_compiletime_offset_1d, layout_runtime_offset_1d); + } + // size(layout)-d check, don't check if access is hierarchical + if constexpr(!IsNestedTuple(Idxs{})) + { + ck::static_for<0, Idxs::Size(), 1>{}([&](auto d) { + EXPECT_EQ(desc.GetLength(ck::Number{}), ck::wrapper::size(layout_runtime)); + EXPECT_EQ(ck::wrapper::size(layout_runtime), + ck::wrapper::size(layout_compiletime)); + }); + } + for(const auto idx : idxs) + { + const ck::index_t layout_runtime_offset = layout_runtime(idx); + const ck::index_t layout_compiletime_offset = layout_compiletime(idx); + const ck::index_t desc_offset = + desc.CalculateOffset(UnrollNestedTuple(idx)); // Unroll if nested + EXPECT_EQ(layout_runtime_offset, desc_offset); + EXPECT_EQ(layout_runtime_offset, layout_compiletime_offset); + } + } +}; + +TEST_F(TestWrapperLayout, 2d) +{ + // dims:(4, 3) strides:(1, 4) + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s1 = 1; + constexpr ck::index_t s0 = 4; + const auto desc = + ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(d1, d0)); + const auto layout_compiletime = + ck::wrapper::make_layout(ck::make_tuple(ck::Number{}, ck::Number{})); + std::vector> idxs; + + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs.emplace_back(h, w); + } + } + + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs); +} + +TEST_F(TestWrapperLayout, 3d_nested) +{ + // dims:((2, 3), 4, 3) strides:((2, 4), 12, 48) + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s3 = 2; + constexpr ck::index_t s2 = 4; + constexpr ck::index_t s1 = 12; + constexpr ck::index_t s0 = 48; + const auto desc = ck::make_naive_tensor_descriptor( + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, 
ck::Number{}, ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))), + ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_3d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)), + ck::make_pass_through_transform(d1), + ck::make_pass_through_transform(d2)), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); + const auto layout_runtime = + ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), d1, d0), + ck::make_tuple(ck::make_tuple(s3, s2), s1, s0)); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple( + ck::make_tuple(ck::Number{}, ck::Number{}), ck::Number{}, ck::Number{}), + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::Number{}, + ck::Number{})); + std::vector> idxs_3d; + + for(ck::index_t d = 0; d < d2 * d3; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_3d.emplace_back(d, h, w); + } + } + } + this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d); + + // Check also 4d iteration + std::vector, ck::index_t, ck::index_t>> idxs_4d; + + for(ck::index_t e = 0; e < d3; e++) + { + for(ck::index_t d = 0; d < d2; d++) + { + for(ck::index_t h = 0; h < d1; h++) + { + for(ck::index_t w = 0; w < d0; w++) + { + idxs_4d.emplace_back(ck::make_tuple(e, d), h, w); + } + } + } + } + this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d); +} + +TEST_F(TestWrapperLayout, 2d_nested) +{ + // dims:((2, 3), (4, 3)) strides:((2, 4), (48, 12)) + constexpr ck::index_t d3 = 2; + constexpr ck::index_t d2 = 3; + constexpr ck::index_t d1 = 4; + constexpr ck::index_t d0 = 3; + constexpr ck::index_t s3 = 2; + constexpr ck::index_t s2 = 4; + constexpr ck::index_t s1 = 48; + constexpr ck::index_t s0 = 12; + const auto desc = ck::make_naive_tensor_descriptor( + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}, ck::Number{}, ck::Number{})); + // Reverse due to column major + const auto desc_1d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3))), + ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), + ck::make_tuple(ck::Sequence<0>{})); + const auto desc_2d = transform_tensor_descriptor( + desc, + ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3)), + ck::make_merge_transform(ck::make_tuple(d0, d1))), + ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}), + ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); + const auto layout_runtime = + ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(d3, d2), ck::make_tuple(d1, d0)), + ck::make_tuple(ck::make_tuple(s3, s2), ck::make_tuple(s1, s0))); + const auto layout_compiletime = ck::wrapper::make_layout( + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})), + ck::make_tuple(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{}))); + std::vector> idxs_2d; + + for(ck::index_t h = 0; h < d2 * d3; h++) + { + for(ck::index_t w = 0; w < d0 * d1; w++) + { + idxs_2d.emplace_back(h, w); + } + } + this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d); + // Check 
also 4d iteration
+    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::Tuple<ck::index_t, ck::index_t>>>
+        idxs_4d;
+
+    for(ck::index_t e = 0; e < d3; e++)
+    {
+        for(ck::index_t d = 0; d < d2; d++)
+        {
+            for(ck::index_t h = 0; h < d1; h++)
+            {
+                for(ck::index_t w = 0; w < d0; w++)
+                {
+                    idxs_4d.emplace_back(ck::make_tuple(e, d), ck::make_tuple(h, w));
+                }
+            }
+        }
+    }
+    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_4d);
+}
+
+TEST_F(TestWrapperLayout, 3d_double_nested)
+{
+    // dims:(((2, 2), 3), (4, 3)) strides:(((2, 4), 8), (96, 24))
+    constexpr ck::index_t d4 = 2;
+    constexpr ck::index_t d3 = 2;
+    constexpr ck::index_t d2 = 3;
+    constexpr ck::index_t d1 = 4;
+    constexpr ck::index_t d0 = 3;
+    constexpr ck::index_t s4 = 2;
+    constexpr ck::index_t s3 = 4;
+    constexpr ck::index_t s2 = 8;
+    constexpr ck::index_t s1 = 96;
+    constexpr ck::index_t s0 = 24;
+    const auto desc = ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<d4>{},
+                                                                      ck::Number<d3>{},
+                                                                      ck::Number<d2>{},
+                                                                      ck::Number<d1>{},
+                                                                      ck::Number<d0>{}),
+                                                       ck::make_tuple(ck::Number<s4>{},
+                                                                      ck::Number<s3>{},
+                                                                      ck::Number<s2>{},
+                                                                      ck::Number<s1>{},
+                                                                      ck::Number<s0>{}));
+    // Reverse due to column major
+    const auto desc_1d = transform_tensor_descriptor(
+        desc,
+        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d0, d1, d2, d3, d4))),
+        ck::make_tuple(ck::Sequence<4, 3, 2, 1, 0>{}),
+        ck::make_tuple(ck::Sequence<0>{}));
+    const auto desc_3d = transform_tensor_descriptor(
+        desc,
+        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d3, d4)),
+                       ck::make_pass_through_transform(d2),
+                       ck::make_merge_transform(ck::make_tuple(d0, d1))),
+        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<4, 3>{}),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{}));
+    const auto desc_2d = transform_tensor_descriptor(
+        desc_3d,
+        ck::make_tuple(ck::make_merge_transform(ck::make_tuple(d2, d3 * d4)),
+                       ck::make_pass_through_transform(d1 * d0)),
+        ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}),
+        ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}));
+    const auto layout_runtime = ck::wrapper::make_layout(
+        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)),
+        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0)));
+    const auto layout_compiletime = ck::wrapper::make_layout(
+        ck::make_tuple(
+            ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
+            ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})),
+        ck::make_tuple(
+            ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
+            ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{})));
+    std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs_2d;
+
+    for(ck::index_t h = 0; h < d2 * d3 * d4; h++)
+    {
+        for(ck::index_t w = 0; w < d0 * d1; w++)
+        {
+            idxs_2d.emplace_back(h, w);
+        }
+    }
+    this->Run(desc_2d, desc_1d, layout_runtime, layout_compiletime, idxs_2d);
+    // Check also 3d iteration
+    std::vector<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>> idxs_3d;
+
+    for(ck::index_t d = 0; d < d3 * d4; d++)
+    {
+        for(ck::index_t h = 0; h < d2; h++)
+        {
+            for(ck::index_t w = 0; w < d1 * d0; w++)
+            {
+                idxs_3d.emplace_back(ck::make_tuple(d, h), w);
+            }
+        }
+    }
+    this->Run(desc_3d, desc_1d, layout_runtime, layout_compiletime, idxs_3d);
+    // Check also 5d iteration
+    std::vector<ck::Tuple<ck::Tuple<ck::Tuple<ck::index_t, ck::index_t>, ck::index_t>,
+                          ck::Tuple<ck::index_t, ck::index_t>>>
+        idxs_5d;
+
+    for(ck::index_t f = 0; f < d4; f++)
+    {
+        for(ck::index_t e = 0; e < d3; e++)
+        {
+            for(ck::index_t d = 0; d < d2; d++)
+            {
+                for(ck::index_t h = 0; h < d1; h++)
+                {
+                    for(ck::index_t w = 0; w < d0; w++)
+                    {
+                        idxs_5d.emplace_back(ck::make_tuple(ck::make_tuple(f, e), d),
+                                             ck::make_tuple(h, w));
+                    }
+                }
+            }
+        }
+    }
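+    // Illustrative note (assuming plain strided addressing as constructed
+    // above): a nested index (((f, e), d), (h, w)) should land at offset
+    // f * s4 + e * s3 + d * s2 + h * s1 + w * s0, e.g.
+    // (((1, 0), 2), (0, 1)) -> 1 * 2 + 0 * 4 + 2 * 8 + 0 * 96 + 1 * 24 = 42.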
+    this->Run(desc, desc_1d, layout_runtime, layout_compiletime, idxs_5d);
+}
+
+TEST(TestLayoutHelpers, SizeAndGet)
+{
+    // dims:(((2, 2), 3), (4, 3))
+    constexpr ck::index_t d4 = 2;
+    constexpr ck::index_t d3 = 2;
+    constexpr ck::index_t d2 = 3;
+    constexpr ck::index_t d1 = 4;
+    constexpr ck::index_t d0 = 3;
+    const auto layout_runtime = ck::wrapper::make_layout(
+        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
+    const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
+        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
+        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
+
+    // Size of layout
+    EXPECT_EQ(ck::wrapper::size(layout_runtime), d4 * d3 * d2 * d1 * d0);
+    EXPECT_EQ(ck::wrapper::size(layout_compiletime), d4 * d3 * d2 * d1 * d0);
+
+    // Size of dims
+    EXPECT_EQ(ck::wrapper::size<0>(layout_runtime), d4 * d3 * d2);
+    EXPECT_EQ(ck::wrapper::size<0>(layout_compiletime), d4 * d3 * d2);
+    EXPECT_EQ(ck::wrapper::size<1>(layout_runtime), d1 * d0);
+    EXPECT_EQ(ck::wrapper::size<1>(layout_compiletime), d1 * d0);
+
+    // Access through new layout (using get with layout object)
+    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_runtime)), d4 * d3);
+    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(layout_compiletime)), d4 * d3);
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);
+
+    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d4);
+    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
+              d4);
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_runtime))), d3);
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(ck::wrapper::get<0>(layout_compiletime))),
+              d3);
+
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_runtime)), d2);
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<0>(layout_compiletime)), d2);
+
+    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_runtime)), d1);
+    EXPECT_EQ(ck::wrapper::size<0>(ck::wrapper::get<1>(layout_compiletime)), d1);
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_runtime)), d0);
+    EXPECT_EQ(ck::wrapper::size<1>(ck::wrapper::get<1>(layout_compiletime)), d0);
+}
+
+TEST(TestLayoutHelpers, DepthAndRank)
+{
+    // dims:(((2, 2), 3), (4, 3))
+    constexpr ck::index_t d4 = 2;
+    constexpr ck::index_t d3 = 2;
+    constexpr ck::index_t d2 = 3;
+    constexpr ck::index_t d1 = 4;
+    constexpr ck::index_t d0 = 3;
+    const auto layout_runtime = ck::wrapper::make_layout(
+        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0)));
+    const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
+        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
+        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
+
+    EXPECT_EQ(ck::wrapper::depth(layout_runtime), 3);
+    EXPECT_EQ(ck::wrapper::depth(layout_compiletime), 3);
+    EXPECT_EQ(ck::wrapper::depth(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
+    // Check for integer
+    EXPECT_EQ(ck::wrapper::depth(d0), 0);
+
+    EXPECT_EQ(ck::wrapper::rank(layout_runtime), 2);
+    EXPECT_EQ(ck::wrapper::rank(layout_compiletime), 2);
+    EXPECT_EQ(ck::wrapper::rank(ck::make_tuple(ck::make_tuple(d4, d3), d2)), 2);
+    // Check for integer
+    EXPECT_EQ(ck::wrapper::rank(d0), 1);
+}
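+// Note on the conventions exercised above: rank() counts only top-level modes
+// (so the shape (((2, 2), 3), (4, 3)) has rank 2), depth() counts levels of
+// tuple nesting (3 for that doubly nested shape), and a plain integer has
+// rank 1 and depth 0.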
+
+TEST(TestLayoutHelpers, ShapeAndStrides)
+{
+    // dims:(((2, 2), 3), (4, 3))
+    constexpr ck::index_t d4 = 2;
+    constexpr ck::index_t d3 = 2;
+    constexpr ck::index_t d2 = 3;
+    constexpr ck::index_t d1 = 4;
+    constexpr ck::index_t d0 = 3;
+    constexpr ck::index_t s4 = 2;
+    constexpr ck::index_t s3 = 4;
+    constexpr ck::index_t s2 = 8;
+    constexpr ck::index_t s1 = 96;
+    constexpr ck::index_t s0 = 24;
+    const auto shape_compiletime = ck::make_tuple(
+        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
+        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));
+    const auto strides_compiletime = ck::make_tuple(
+        ck::make_tuple(ck::make_tuple(ck::Number<s4>{}, ck::Number<s3>{}), ck::Number<s2>{}),
+        ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
+    const auto shape_runtime =
+        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
+    const auto strides_runtime =
+        ck::make_tuple(ck::make_tuple(ck::make_tuple(s4, s3), s2), ck::make_tuple(s1, s0));
+    const auto layout_runtime = ck::wrapper::make_layout(shape_runtime, strides_runtime);
+    const auto layout_compiletime =
+        ck::wrapper::make_layout(shape_compiletime, strides_compiletime);
+
+    constexpr bool check_compiletime_shape =
+        std::is_same_v<std::remove_const<decltype(shape_compiletime)>::type,
+                       decltype(shape(layout_compiletime))>;
+    constexpr bool check_compiletime_strides =
+        std::is_same_v<std::remove_const<decltype(strides_compiletime)>::type,
+                       decltype(stride(layout_compiletime))>;
+    constexpr bool check_runtime_shape =
+        std::is_same_v<std::remove_const<decltype(shape_runtime)>::type,
+                       decltype(shape(layout_runtime))>;
+    constexpr bool check_runtime_strides =
+        std::is_same_v<std::remove_const<decltype(strides_runtime)>::type,
+                       decltype(stride(layout_runtime))>;
+    EXPECT_TRUE(check_compiletime_shape);
+    EXPECT_TRUE(check_compiletime_strides);
+    EXPECT_TRUE(check_runtime_shape);
+    EXPECT_TRUE(check_runtime_strides);
+}
+
+TEST(TestLayoutHelpers, Hierarchical)
+{
+    // dims:(((2, 2), 3), (4, 3))
+    constexpr ck::index_t d4 = 2;
+    constexpr ck::index_t d3 = 2;
+    constexpr ck::index_t d2 = 3;
+    constexpr ck::index_t d1 = 4;
+    constexpr ck::index_t d0 = 3;
+    const auto runtime_shape =
+        ck::make_tuple(ck::make_tuple(ck::make_tuple(d4, d3), d2), ck::make_tuple(d1, d0));
+    const auto layout_runtime = ck::wrapper::make_layout(runtime_shape);
+    const auto layout_compiletime = ck::wrapper::make_layout(ck::make_tuple(
+        ck::make_tuple(ck::make_tuple(ck::Number<d4>{}, ck::Number<d3>{}), ck::Number<d2>{}),
+        ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{})));
+
+    EXPECT_EQ((ck::wrapper::rank<0, 0>(runtime_shape)), 2);
+    EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_runtime)), 2);
+    EXPECT_EQ((ck::wrapper::rank<0, 0>(layout_compiletime)), 2);
+
+    EXPECT_EQ((ck::wrapper::depth<0, 0>(runtime_shape)), 1);
+    EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_runtime)), 1);
+    EXPECT_EQ((ck::wrapper::depth<0, 0>(layout_compiletime)), 1);
+
+    EXPECT_EQ((ck::wrapper::size<0, 0>(runtime_shape)), d4 * d3);
+    EXPECT_EQ((ck::wrapper::size<0, 0>(layout_runtime)), d4 * d3);
+    EXPECT_EQ((ck::wrapper::size<0, 0>(layout_compiletime)), d4 * d3);
+
+    EXPECT_EQ((ck::wrapper::get<0, 0, 0>(runtime_shape)), d4);
+}

From 6896c3b0ae3da9adfa3cd4979621cee642257fc3 Mon Sep 17 00:00:00 2001
From: Illia Silin <98187287+illsilin@users.noreply.github.com>
Date: Wed, 6 Dec 2023 12:48:10 -0800
Subject: [PATCH 10/75] Fix the CI builds using clang++ directly.
(#1087) * turn on -O3 compiler flag explicitly * change cmake syntax for CI * modify cmake line breaks in jenkinsfile --- Jenkinsfile | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 8e67f9cc39..d5fbff288f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -768,8 +768,15 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \ + -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -784,8 +791,12 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908;gfx90a" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -800,8 +811,12 @@ pipeline { } agent{ label rocmnode("navi21") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1030" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -816,8 +831,12 @@ pipeline { } agent{ label rocmnode("navi32") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1101" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') From 957281ce45025f674c75ee3e318257d9df3a52d7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 7 Dec 2023 10:32:04 -0700 Subject: [PATCH 11/75] Bump rocm-docs-core from 0.29.0 to 0.30.1 in /docs/sphinx (#1090) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.29.0 to 0.30.1. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.29.0...v0.30.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index f5ee431e7d..0a65ffc81a 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.29.0 +rocm-docs-core==0.30.1 sphinxcontrib-bibtex==2.6.1 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 0442ae9a2b..01cb32e714 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -96,7 +96,9 @@ pygments==2.14.0 # pydata-sphinx-theme # sphinx pyjwt[crypto]==2.6.0 - # via pygithub + # via + # pygithub + # pyjwt pynacl==1.5.0 # via pygithub pytz==2023.3.post1 @@ -111,7 +113,7 @@ requests==2.28.2 # via # pygithub # sphinx -rocm-docs-core==0.29.0 +rocm-docs-core==0.30.1 # via -r requirements.in six==1.16.0 # via From 33600202c644f64d3596d6340466982895772822 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Thu, 7 Dec 2023 13:39:40 -0600 Subject: [PATCH 12/75] remove imcomplete transpose profiler (#1088) Co-authored-by: Jing Zhang Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- profiler/src/profile_transpose.cpp | 85 ------------------------------ 1 file changed, 85 deletions(-) delete mode 100644 profiler/src/profile_transpose.cpp diff --git a/profiler/src/profile_transpose.cpp b/profiler/src/profile_transpose.cpp deleted file mode 100644 index c239a520d1..0000000000 --- a/profiler/src/profile_transpose.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "profiler/profile_transpose_impl.hpp" -#include "profiler_operation_registry.hpp" - -enum struct MatrixLayout -{ - NCDHW, // 0 - NCHWD, // 1 -}; - -enum struct DataType -{ - F32_F32_F32_F32_F32, // 0 - F16_F16_F16_F16_F16, // 1 -}; - -#define OP_NAME "transpose" -#define OP_DESC "Transpose" - -int profile_transpose(int argc, char* argv[]) -{ - if(argc != 15) - { - printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - // printf("arg3: matrix layout (NCDHW -> NDCHW);\n"); - printf("arg4: verification (0: no; 1: yes)\n"); - printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg6: print tensor value (0: no; 1: yes)\n"); - printf("arg7: time kernel (0=no, 1=yes)\n"); - printf("arg8 to 13: N, C, D, H, W\n"); - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - // const auto layout = static_cast(std::stoi(argv[3])); - const bool do_verification = std::stoi(argv[3]); - const int init_method = std::stoi(argv[4]); - const bool do_log = std::stoi(argv[5]); - const bool time_kernel = std::stoi(argv[6]); - std::vector lengths = std::stoi(argv[7]); - - /**const int N = std::stoi(argv[7]); - const int C = std::stoi(argv[8]); - const int D = std::stoi(argv[9]); - const int H = std::stoi(argv[10]); - const int W = std::stoi(argv[11]);**/ - - using F32 = float; - using F16 = ck::half_t; - - auto profile = [&](auto a_type, auto b_type) { - using ADataType = decltype(a_type); - using BDataType = decltype(b_type); - - bool pass = ck::profiler::profile_transpose_impl( - do_verification, init_method, do_log, time_kernel, lengths); - - return pass ? 
0 : 1; - }; - - if(data_type == GemmDataType::F32_F32_F32_F32_F32) - { - return profile(F32{}, F32{}); - } - else if(data_type == GemmDataType::F16_F16_F16_F16_F16) - { - return profile(F16{}, F16{}); - } - else - { - std::cout << "this data_type & layout is not implemented" << std::endl; - - return 1; - } -} - -REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_transpose); From d939411dae1aa0e09fecb466cfdc1e3044085720 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 7 Dec 2023 15:59:34 -0800 Subject: [PATCH 13/75] Switch from ROCmSoftwarePlatform to ROCm org (#1091) * switch from ROCmSoftwarePlatform to ROCm org * replace ROCmSoftwarePlatform with ROCm in few more places --- CITATION.cff | 4 ++-- Jenkinsfile | 10 +++++----- README.md | 2 +- dev-requirements.txt | 4 ++-- include/ck/host_utility/device_prop.hpp | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index d35fe9e587..3813d63812 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -59,9 +59,9 @@ authors: family-names: Zhou - given-names: Jianfeng family-names: Yan -repository-code: 'https://github.com/ROCmSoftwarePlatform/composable_kernel' +repository-code: 'https://github.com/ROCm/composable_kernel' abstract: Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for Machine Learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel progarmming languages, like HIP C++. keywords: - 'CK, Composable Kernel, Tensor Coordinate Transformation' license: MIT -license-url: https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE +license-url: https://github.com/ROCm/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE diff --git a/Jenkinsfile b/Jenkinsfile index d5fbff288f..8f661e4780 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -302,7 +302,7 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 5, unit: 'HOURS') { @@ -355,7 +355,7 @@ def runCKProfiler(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -487,7 +487,7 @@ def Build_CK(Map conf=[:]){ def retimage def navi_node = 0 - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -553,7 +553,7 @@ def Build_CK(Map 
conf=[:]){ sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip rm -rf hipTensor-"${params.hipTensor_branch}" - wget https://github.com/ROCmSoftwarePlatform/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip + wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip unzip -o "${params.hipTensor_branch}".zip """ dir("hipTensor-${params.hipTensor_branch}"){ @@ -605,7 +605,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) } diff --git a/README.md b/README.md index e5a20f143f..7679607e69 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa 3. Clone CK source code from the GitHub repository and start the build: ```bash - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git && \ + git clone https://github.com/ROCm/composable_kernel.git && \ cd composable_kernel && \ mkdir build && \ cd build diff --git a/dev-requirements.txt b/dev-requirements.txt index 9e7b9f01e1..d5d91f8c27 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ -ROCmSoftwarePlatform/rocm-recipes +ROCm/rocm-recipes RadeonOpenCompute/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build -danmar/cppcheck@2.9 \ No newline at end of file +danmar/cppcheck@2.9 diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index be2c2395fc..e8dabc9973 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -26,7 +26,7 @@ inline std::string get_device_name() } const std::string raw_name(props.gcnArchName); - // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 static std::map device_name_map = { {"Ellesmere", "gfx803"}, {"Baffin", "gfx803"}, From f83698489109205dfe1780ce63c032b2a27e7434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 8 Dec 2023 11:07:42 +0100 Subject: [PATCH 14/75] Support broadcast for bias in grouped conv fwd (#1081) * Support broadcast for bias in grouped conv fwd * Fix comment * Comment fixes * Remove GK layout --- ...rouped_conv_fwd_scaleadd_scaleadd_relu.inc | 16 +- example/62_conv_fwd_activ/CMakeLists.txt | 2 + ...aleadd_scaleadd_relu_bcasted_bias_fp16.cpp | 294 ++++++++++++++++++ .../run_convnd_fwd_activ_example.inc | 2 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 32 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 3 +- .../gpu/device/tensor_layout.hpp | 6 - .../transform_conv_fwd_to_gemm.hpp | 15 +- .../device_operation_instance_factory.hpp | 6 +- ...olution_forward_scaleadd_scaleadd_relu.hpp | 12 +- ...elu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 8 +- ...relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 8 +- ...relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 8 +- ...elu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp | 8 +- .../conv2d_fwd/conv2d_quantization_common.hpp | 6 +- 15 files changed, 371 insertions(+), 55 deletions(-) create mode 100644 
example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc b/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc index c72c72971d..e8f5529520 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc +++ b/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc @@ -16,6 +16,7 @@ using InLayout = ck::tensor_layout::convolution::NDHWGC; using WeiLayout = ck::tensor_layout::convolution::GKZYXC; using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using BiasLayout = ck::tensor_layout::convolution::G_K; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; @@ -64,6 +65,9 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() std::array out_lengths{G, N, K, Do, Ho, Wo}; std::array out_strides{ K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; std::array filter_strides{1, 1, 1}; std::array filter_dilations{1, 1, 1}; @@ -74,13 +78,13 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K); - SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * N * Do * Ho * Wo * G * K); + SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NumDimSpatial, InLayout, WeiLayout, - ck::Tuple, + ck::Tuple, OutLayout, InDataType, WeiDataType, @@ -117,8 +121,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() in_strides, wei_lengths, wei_strides, - {out_lengths, out_lengths}, - {out_strides, out_strides}, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, out_lengths, out_strides, filter_strides, @@ -187,8 +191,8 @@ int execute_conv_fwd_scaleadd_scaleadd_relu() in_strides, wei_lengths, wei_strides, - {out_lengths, out_lengths}, - {out_strides, out_strides}, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, out_lengths, out_strides, filter_strides, diff --git a/example/62_conv_fwd_activ/CMakeLists.txt b/example/62_conv_fwd_activ/CMakeLists.txt index bb95602416..d1f26bbfe1 100644 --- a/example/62_conv_fwd_activ/CMakeLists.txt +++ b/example/62_conv_fwd_activ/CMakeLists.txt @@ -42,6 +42,8 @@ foreach(gpu IN LISTS GPU_TARGETS) # ScaleAdd ScaleAdd Relu add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) + add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) set(target 1) endif() endforeach() diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp 
b/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp new file mode 100644 index 0000000000..196636f8b5 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using BiasLayout = ck::tensor_layout::convolution::G_K; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool 
run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 2; + const ck::index_t G = out_g_n_k_wos_desc.GetLengths()[0]; + const ck::index_t K = out_g_n_k_wos_desc.GetLengths()[2]; + + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_g_k_lengths; + std::array bias_g_k_strides; + // Fill other lenghts than G,K with 1 and strides with 0 + bias_g_k_lengths.fill(1); + bias_g_k_strides.fill(0); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + + // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + std::array, NumDs> d_tensors = {Tensor(out_g_n_k_wos_desc), + Tensor(broadcasted_bias_desc)}; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + std::cout << "z_tensor: " << d_tensors[0].mDesc << std::endl; + std::cout << "bias_tensor: " << d_tensors[1].mDesc << std::endl; + + // Make sure that we allocated only G * K values for bias + assert(static_cast(d_tensors[1].mData.size()) == G * K); + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem z_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize()); + DeviceMem bias_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + z_buf.ToDevice(d_tensors[0].mData.data()); + bias_buf.ToDevice(d_tensors[1].mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + 
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + const std::array ds = {z_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{ + e_g_n_k_wos_lengths, bias_g_k_lengths}, + std::array, NumDs>{ + e_g_n_k_wos_strides, bias_g_k_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + G * K + + conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + G * K * sizeof(OutDataType) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc index 7c20c01066..aa547c870a 100644 --- a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc +++ b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc @@ -24,7 +24,7 @@ bool run_convnd_fwd_example(int argc, char* argv[]) // Following shapes are selected to avoid overflow. Expect inf in case of // size increase for some elementwise ops. 
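+    // The ConvParam fields below are, presumably, {NDim, G, N, K, C,
+    // filter_spatial, input_spatial, conv_strides, dilations, left_pads,
+    // right_pads}; the group count G goes from 1 to 2 so that the broadcast
+    // bias is exercised across more than one group.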
ck::utils::conv::ConvParam conv_param{ - 3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + 3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; if(argc == 1) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 26224b5dec..4afef85d8c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -357,15 +357,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle return out_gemmm_gemmn_desc; } + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. static auto MakeDsGridDescriptor_M_N( - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_lengths, const std::array, NumDTensor>& ds_g_n_k_wos_strides) { return generate_tuple( [&](auto i) { using DLayout = remove_cvref_t>; - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], + return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); }, Number{}); @@ -569,7 +571,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle // D desc ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); }); compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; @@ -916,8 +918,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v) { const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; @@ -925,6 +926,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { valid = false; } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + valid = false; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + valid = false; + } + } + } } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 80a5d0e97a..0050a5b281 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v) { const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index b2d141fd61..ecc71ba2f2 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -308,12 +308,6 @@ struct GNDHWK : public BaseTensorLayout static constexpr 
const char* name = "GNDHWK"; }; -// for output bias -struct GK : public BaseTensorLayout -{ - static constexpr const char* name = "GK"; -}; - // output tensor // packed NWGK/NHWGK/NDHWGK struct NWGK : public BaseTensorLayout diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 6f546f1d6d..e2f75142d4 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -522,22 +522,21 @@ struct TransformConvFwdToGemm // for output bias template || - is_same_v, + typename std::enable_if, bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + const index_t KStride = c_g_n_k_wos_strides[2]; const index_t NHoWo = N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1)); + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); return out_gemmm_gemmn_desc; } diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 89b8b9667f..dc47c7ec1a 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -86,9 +86,9 @@ using NHWGK = ck::tensor_layout::convolution::NHWGK; using NDHWGK = ck::tensor_layout::convolution::NDHWGK; // -using GK = ck::tensor_layout::convolution::G_K; -using GK_Tuple = ck::Tuple; -using GK_GK_Tuple = ck::Tuple; +using G_K = ck::tensor_layout::convolution::G_K; +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; // pointwise functor using PassThrough = ck::tensor_operation::element_wise::PassThrough; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp index dc9f44dc86..efb6266426 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp @@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, BF16, BF16, @@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F16, F16, @@ -59,7 +59,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F32, F32, @@ -75,7 +75,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, int8_t, int8_t, @@ -130,7 +130,9 @@ struct 
DeviceOperationInstanceFactory> op_ptrs; if constexpr(NumDimSpatial == 3 && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + DLayouts::Size() == 2 && is_same_v, NDHWGK> && + is_same_v, G_K>) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index c6627a4825..7d2df94ad7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, BF16, BF16, @@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); add_device_operation_instances( @@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_bf16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index 627af24d7b..8a09d03967 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F16, F16, @@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); add_device_operation_instances( @@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw 
device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f16_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp index 1fd567e360..6966959639 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -13,7 +13,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, F32, F32, @@ -28,7 +28,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -36,7 +36,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); add_device_operation_instances( @@ -44,7 +44,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_f32_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp index dae292891c..2606f69428 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -12,7 +12,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw std::vector, + ck::Tuple, NDHWGK, int8_t, int8_t, @@ -27,7 +27,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwdDefault>{}); add_device_operation_instances( @@ -35,7 +35,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1P0>{}); add_device_operation_instances( @@ -43,7 +43,7 @@ void add_device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhw device_grouped_conv_fwd_xdl_scaleadd_scaleadd_relu_int8_instances<3, NDHWGC, GKZYXC, - ck::Tuple, + ck::Tuple, NDHWGK, ConvFwd1x1S1P0>{}); } diff 
--git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp index 711314985a..d46fe090b8 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp @@ -22,13 +22,13 @@ using S = ck::Sequence; using NHWGC = ck::tensor_layout::convolution::NHWGC; using GKYXC = ck::tensor_layout::convolution::GKYXC; using NHWGK = ck::tensor_layout::convolution::NHWGK; -using GK = ck::tensor_layout::convolution::G_K; +using G_K = ck::tensor_layout::convolution::G_K; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Relu = ck::tensor_operation::element_wise::Relu; using TanH = ck::tensor_operation::element_wise::TanH; -using GK_Tuple = ck::Tuple; -using GK_GK_Tuple = ck::Tuple; +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; using I32_Tuple = ck::Tuple; using F32_Tuple = ck::Tuple; using I32_F32_Tuple = ck::Tuple; From b4dcd5803f1dae92467d39c31f176131ce796735 Mon Sep 17 00:00:00 2001 From: Nicolas Macchioni Date: Fri, 8 Dec 2023 11:30:01 -0800 Subject: [PATCH 15/75] Add F8 dtype definition in f16_f8_f16 gemm instances (#1092) --- .../device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp | 1 + .../device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp index 3c9e03b674..38667ad42b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp index aab0af990d..820404e064 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp @@ -16,6 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { +using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; From f199035b748331901b7e0d58cbcd88e108bdcadd Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 8 Dec 2023 14:32:37 -0800 Subject: [PATCH 16/75] fix clang format (#1095) --- .../device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp | 2 +- .../device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp index 38667ad42b..b3d1e925df 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instance.cpp @@ -16,7 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = ck::f8_t; +using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp index 820404e064..9c80995949 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instance.cpp @@ -16,7 +16,7 @@ namespace tensor_operation { namespace device { namespace instance { -using F8 = ck::f8_t; +using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; From 89ee47460bedd3e028a6240c2c395023cb233f4c Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Mon, 11 Dec 2023 17:12:32 +0100 Subject: [PATCH 17/75] Fix IsSupported check in the contraction op (#1066) Current implementation of IsSupported method in contraction ops does not cover a lot of possible cases in which ScalarPerVector cannot really be used to read A, B or D, or write E. This PR extends both the regular and multiABD contraction ops with improved checks and also adds new instances with smaller values of ScalarPerVector to support instances that are not supported by other instances. --- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 153 ++++++++-------- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 163 ++++++++---------- .../device/impl/device_contraction_utils.hpp | 87 ++++++++++ .../device_contraction_instance.hpp | 24 ++- 4 files changed, 261 insertions(+), 166 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index 29d7a2b949..0c8e11a17b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // for sanity check of vector memory access for(index_t i = 0; i < NumATensor; ++i) { - a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1]; - a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1]; + as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1; + as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1; + as_max_read_elems_[i] = + CalculateMaxRead(a_ms_ks_lengths[i], a_ms_ks_strides[i]); } for(index_t i = 0; i < NumBTensor; ++i) { - b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 
1]; - b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1]; + bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1; + bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1; + bs_max_read_elems_[i] = + CalculateMaxRead(b_ns_ks_lengths[i], b_ns_ks_strides[i]); } for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1]; + ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; + ds_max_read_elems_[i] = + CalculateMaxRead(d_ms_ns_lengths[i], d_ms_ns_strides[i]); } - e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1]; + e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1; + e_max_write_elems_ = CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); } // pointers @@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Strides for the last M/N/K dimensions of A/B/Ds/E - // for sanity check of vector load/store - std::array a_mz_stride_; - std::array a_kz_stride_; + // Describe whether the last part of a given dimension of A/B/D/E is consecutive + // in the memory or not. + std::array as_mz_consecutive_; + std::array as_kz_consecutive_; + std::array bs_nz_consecutive_; + std::array bs_kz_consecutive_; + std::array ds_nz_consecutive_; + bool e_nz_consecutive_; - std::array b_nz_stride_; - std::array b_kz_stride_; - - std::array ds_nz_stride_; - index_t e_nz_stride_; + std::array as_max_read_elems_; + std::array bs_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; }; // Invoker @@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle // check vector load/store { - bool all_valid = true; - + bool valid_as_access = true; static_for<0, NumATensor, 1>{}([&](auto i) { - // vector memory access of A: could be on M or AK1 dimension - if constexpr(ABlockTransferSrcVectorDim == 1) + const bool valid_a_vector_size = + arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i]; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i]; + const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; + if(!(valid_a_vector_size && valid_a_access_dim)) { - if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) % - ABlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - else - { - if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) % - ABlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } + valid_as_access = false; } }); - - // vector memory access of B: could be on N or BK1 dimension - static_for<0, NumBTensor, 1>{}([&](auto i) { - if constexpr(BBlockTransferSrcVectorDim == 1) - { - if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) % - BBlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - else - { - if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) % - BBlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - }); - - // check vector load of Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - if(!(arg.ds_nz_stride_[i] == 1 && - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) - { - all_valid = false; - } - }); - - // vector memory access of E: always on 
NPerBlock dimension - if(!(arg.e_nz_stride_ == 1 && - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) + if(!valid_as_access) { - all_valid = false; + return false; } - if(!all_valid) + bool valid_bs_access = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + const bool valid_b_vector_size = + arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i]; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i]; + const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; + if(!(valid_b_vector_size && valid_b_access_dim)) + { + valid_bs_access = false; + } + }); + if(!valid_bs_access) + { + return false; + } + + bool valid_ds_access = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. + const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + if(!(valid_d_vector_size && valid_d_access_dim)) + { + valid_ds_access = false; + } + }); + if(!valid_ds_access) + { + return false; + } + + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. + const bool valid_e_access_dim = arg.e_nz_consecutive_; + if(!(valid_e_vector_size && valid_e_access_dim)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 71ff2ba17d..290abe221a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return generate_tuple([&](auto i) { return vec[i]; }, num); }; - const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); // dimension Ids for M0, M1, ... @@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle typename arithmetic_sequence_gen::type{}; // lengths for M0, M1, ... - const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); // lengths for K0, K1, ... - const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] 
const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( @@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b_grid, std::array p_ds_grid, void* p_e_grid, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b_grid_{static_cast(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, @@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - cde_element_op_{cde_element_op}, - a_mz_stride_{}, - a_kz_stride_{}, - b_nz_stride_{}, - b_kz_stride_{}, - ds_nz_stride_{}, - e_nz_stride_{} + cde_element_op_{cde_element_op} { // populate pointer, batch stride, desc for Ds static_for<0, NumDTensor, 1>{}([&](auto i) { @@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } // for sanity check of vector memory access - a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; - a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; + a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1; + a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1; + a_max_read_elems_ = + CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); - b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; - b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; + b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1; + b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1; + b_max_read_elems_ = + CalculateMaxRead(b_ns_ks_lengths, b_ns_ks_strides); for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; + ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; + ds_max_read_elems_[i] = + CalculateMaxRead(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); } - e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; + e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1; + e_max_write_elems_ = + CalculateMaxRead(e_ms_ns_lengths, e_ms_ns_strides); } void Print() const @@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Strides for the last M/N/K dimensions of A/B/Ds/E - // for sanity check of vector load/store - index_t a_mz_stride_; - index_t a_kz_stride_; - index_t b_nz_stride_; - index_t b_kz_stride_; - std::array ds_nz_stride_; - index_t e_mz_stride_; - index_t e_nz_stride_; + // Describe whether the last part of a given dimension of A/B/D/E is consecutive + // in the memory or not. 
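+ // E.g. for the 2-2-2 contraction instances (NumDimM = NumDimK = 2) with A
+ // strides [s0, s1, s2, s3], a_mz_consecutive_ checks s1 == 1 and
+ // a_kz_consecutive_ checks s3 == 1: only the innermost dimension of each
+ // group is inspected here, while CalculateMaxRead measures how long the
+ // contiguous run actually is.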
+ bool a_mz_consecutive_; + bool a_kz_consecutive_; + bool b_nz_consecutive_; + bool b_kz_consecutive_; + std::array ds_nz_consecutive_; + bool e_nz_consecutive_; + + index_t a_max_read_elems_; + index_t b_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; }; // Invoker @@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), "wrong!"); - // vector memory access of A: could be on M or AK1 dimension - if constexpr(ABlockTransferSrcVectorDim == 1) - { - if(!(arg.a_mz_stride_ == 1 && - arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - - // vector memory access of B: could be on N or BK1 dimension - if constexpr(BBlockTransferSrcVectorDim == 1) - { - if(!(arg.b_nz_stride_ == 1 && - arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - if(!(arg.b_kz_stride_ == 1 && - arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - - // vector memory access of Ds: always on NPerBlock dimension - bool valid_d_access = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - if(!(arg.ds_nz_stride_[i] == 1 && - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) - { - valid_d_access = false; - } - }); - - if(valid_d_access == false) + const bool valid_a_vector_size = + arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; + const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; + const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; + if(!(valid_a_vector_size && valid_a_access_dim)) { return false; } - // vector memory access of E: always on NPerBlock dimension - if(!(arg.e_nz_stride_ == 1 && - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) + const bool valid_b_vector_size = + arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; + const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; + const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; + if(!(valid_b_vector_size && valid_b_access_dim)) + { + return false; + } + + bool valid_ds_access = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. + const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + if(!(valid_d_vector_size && valid_d_access_dim)) + { + valid_ds_access = false; + } + }); + if(!valid_ds_access) + { + return false; + } + + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. 
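+ // Both conditions must hold for a vectorized store: the contiguous run of E
+ // elements is a whole multiple of the vector width (checked above), and the
+ // innermost N dimension is the dimension that is contiguous (checked below).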
+ const bool valid_e_access_dim = arg.e_nz_consecutive_; + if(!(valid_e_vector_size && valid_e_access_dim)) { return false; } @@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b, std::array p_ds, void* p_e, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b, p_ds, p_e, - a_ms_ns_lengths, + a_ms_ks_lengths, a_ms_ks_strides, b_ns_ks_lengths, b_ns_ks_strides, @@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b, std::array p_ds, void* p_e, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b, p_ds, p_e, - a_ms_ns_lengths, + a_ms_ks_lengths, a_ms_ks_strides, b_ns_ks_lengths, b_ns_ks_strides, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp new file mode 100644 index 0000000000..0e14b40942 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * Calculates the maximum number of subsequent elements of the fast changing dimension + * that are consecutive in memory. + * + * Example: + * NumDimM = 2, NumDimK = 3 + * A shape = [ 2, 3, 4, 5, 6] + * A strides = [360, 120, 30, 6, 1] + * | M | | K | + * It follows from strides that K is FCD and all the subsequent elements of K are consecutive + * in memory. + * But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be + * consecutive in memory. + * + * Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions. + */ +template +auto CalculateMaxRead(const std::vector& lengths, const std::vector& strides) +{ + if(lengths.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + if(strides.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // Determine the beginning and end idx of the group representing the FCD. + index_t begin_idx, end_idx; + if(strides[NumDim1 - 1] == 1) + { + begin_idx = 0; + end_idx = NumDim1 - 1; + } + else if(strides[NumDim1 + NumDim2 - 1] == 1) + { + begin_idx = NumDim1; + end_idx = NumDim1 + NumDim2 - 1; + } + else + { + // The dimension consecutive in memory is not the last dimension of any group, so only + // one element can be read/written at once. 
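+ // E.g. NumDim1 = NumDim2 = 2 with strides [1, 4, 24, 8]: the unit stride
+ // belongs to the first dimension of the first group rather than to either
+ // group's innermost dimension, so no vectorized access is possible.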
+ return 1; + } + + index_t consecutive_stride = 1; + for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx) + { + if(strides[dim_idx] == consecutive_stride) + { + consecutive_stride *= lengths[dim_idx]; + } + else + { + break; + } + } + const index_t max_subsequent_elems = consecutive_stride; + return max_subsequent_elems; +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp index b43d34d69a..b67119ad19 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp @@ -61,7 +61,11 @@ using device_contraction_kk_instance = std::tuple< DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, + // Small scalar per vector + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 
1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; @@ -96,7 +100,11 @@ using device_contraction_kn_instance = std::tuple< DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + // Small scalar per vector + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // 
clang-format on >; @@ -131,7 +139,11 @@ using device_contraction_mk_instance = std::tuple< DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + // Small scalar per vector + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; @@ -166,7 +178,11 @@ using device_contraction_mn_instance = std::tuple< DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, 
S<1, 16, 1, 16>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + // Small scalar per vector + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; From c004e0d99048d76c40da81c4dd2a36921cee0293 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:49:27 -0800 Subject: [PATCH 18/75] disabling some fp8 gemm instances to reduce build time (#1084) * disabling some fp8 gemm instances to reduce build time * disable fp8 gemm instances to reduce build time * remove the unused variable * build fp8 gemm default and padded instances separately * fix include pathsc --- ...shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp} | 20 +++----------- .../tensor_operation_instance/gpu/gemm.hpp | 9 +++++-- .../gpu/gemm/CMakeLists.txt | 3 ++- ..._fp8_fp8_fp8_mk_kn_mn_default_instance.cpp | 26 +++++++++++++++++++ ...e_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp | 26 +++++++++++++++++++ 5 files changed, 64 insertions(+), 20 deletions(-) rename 
library/{src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp => include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp} (96%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp similarity index 96% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp rename to library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp index 82eae9f0a2..005cec94ec 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp @@ -25,10 +25,6 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple< @@ -37,7 +33,7 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple< //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave + // pipeline v1, 1 wave DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, @@ -75,7 +71,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1> #endif -#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES +#if 0 + //CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave , DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, @@ -98,17 +95,6 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple< // clang-format on >; -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances{}); -} - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index bbc70f1a5b..626dd7f00a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -345,7 +345,11 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( std::vector>>& instances); -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances( std::vector>>& instances); @@ -575,7 +579,8 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index d0bcacbe3c..3532c3f4ba 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -101,7 +101,8 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) diff --git 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp new file mode 100644 index 0000000000..baa76a74af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp" + +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp new file mode 100644 index 0000000000..f16809db28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp" + +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif From 6891e4d10965513657d531c3c8c2048aaba34b05 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 13 Dec 2023 14:27:31 -0600 Subject: [PATCH 19/75] Fix the bugs (#1099) --- include/ck/utility/type_convert.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 70bc6f278c..11db866152 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -182,7 +182,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // convert to float and use native converion - return f8_convert_sr(type_convert(x)); + return f8_convert_sr(type_convert(x)); #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -295,7 +295,7 @@ inline __host__ __device__ bf8_t f8_convert_rne(half_t x) template <> inline __host__ __device__ f8_t type_convert(float x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else return f8_convert_rne(x); 
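The "#if defined CK_USE_SR_F8_CONVERSION" -> "#if CK_USE_SR_F8_CONVERSION" changes in this hunk and the ones below are behavioral fixes, not style cleanups: "#if defined M" is true whenever the macro exists at all, even when it is defined as 0, while "#if M" tests the macro's value (an undefined macro evaluates to 0). Assuming CK's build defines CK_USE_SR_F8_CONVERSION with an explicit 0/1 value, the old form would always take the stochastic-rounding branch. A minimal illustration of the difference (not part of the patch):

    #define CK_USE_SR_F8_CONVERSION 0

    #if defined CK_USE_SR_F8_CONVERSION
    // Taken even though the value is 0: 'defined' only tests existence.
    #endif

    #if CK_USE_SR_F8_CONVERSION
    // Skipped: here the macro's value, 0, is what gets evaluated.
    #endif
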
@@ -352,10 +352,10 @@ inline __host__ __device__ half2_t type_convert(float2_t x) template <> inline __host__ __device__ f8_t type_convert(half_t x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else - return f8_convert_nre(x); + return f8_convert_rne(x); #endif } @@ -376,7 +376,7 @@ inline __host__ __device__ half_t type_convert(f8_t x) template <> inline __host__ __device__ bf8_t type_convert(float x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else return f8_convert_rne(x); @@ -403,7 +403,7 @@ inline __host__ __device__ float type_convert(bf8_t x) template <> inline __host__ __device__ bf8_t type_convert(half_t x) { -#if defined CK_USE_SR_F8_CONVERSION +#if CK_USE_SR_F8_CONVERSION return f8_convert_sr(x); #else return f8_convert_rne(x); From 3a3b98ef79d967391840a202a8ddf7b3d05ba823 Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Wed, 13 Dec 2023 12:50:15 -0800 Subject: [PATCH 20/75] [Doc][Werror] Fix security alerts and sync with MIOpen (#1085) * fix Werror unused-parameter * sync doc requirements * fix blank space format * fix dependency issue --- docs/sphinx/requirements.txt | 16 ++++++++-------- .../gpu/grid/gridwise_tensor_rearrange.hpp | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 01cb32e714..75863c214e 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -16,7 +16,7 @@ beautifulsoup4==4.11.2 # via pydata-sphinx-theme breathe==4.34.0 # via rocm-docs-core -certifi==2022.12.7 +certifi==2023.7.22 # via requests cffi==1.15.1 # via @@ -26,7 +26,7 @@ charset-normalizer==3.1.0 # via requests click==8.1.3 # via sphinx-external-toc -cryptography==40.0.2 +cryptography==41.0.6 # via pyjwt deprecated==1.2.13 # via pygithub @@ -42,7 +42,7 @@ fastjsonschema==2.18.0 # via rocm-docs-core gitdb==4.0.10 # via gitpython -gitpython==3.1.35 +gitpython==3.1.37 # via rocm-docs-core idna==3.4 # via requests @@ -88,9 +88,9 @@ pydata-sphinx-theme==0.13.3 # via # rocm-docs-core # sphinx-book-theme -pygithub==1.58.2 +pygithub==1.58.1 # via rocm-docs-core -pygments==2.14.0 +pygments==2.15.0 # via # accessible-pygments # pydata-sphinx-theme @@ -109,7 +109,7 @@ pyyaml==6.0 # pybtex # rocm-docs-core # sphinx-external-toc -requests==2.28.2 +requests==2.31.0 # via # pygithub # sphinx @@ -141,7 +141,7 @@ sphinx-book-theme==1.0.1 # via rocm-docs-core sphinx-copybutton==0.5.1 # via rocm-docs-core -sphinx-design==0.3.0 +sphinx-design==0.4.1 # via rocm-docs-core sphinx-external-toc==0.3.1 # via rocm-docs-core @@ -163,7 +163,7 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx typing-extensions==4.5.0 # via pydata-sphinx-theme -urllib3==1.26.15 +urllib3==1.26.18 # via requests wrapt==1.15.0 # via deprecated diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index f77ffff350..9535ca69a9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -50,7 +50,9 @@ __global__ void ignore = p_in_global; ignore = out_grid_desc; ignore = p_out_global; + ignore = batch_count; ignore = block_2_tile_map; + ignore = compute_ptr_offset_of_batch; #endif } From 281f8369033366669fbabe05ed9622c1370c4a71 Mon Sep 17 00:00:00 2001 From: Lisa Date: Thu, 14 Dec 2023 15:21:18 -0700 Subject: [PATCH 21/75] fix typo (#1067) Co-authored-by: Illia 
Silin <98187287+illsilin@users.noreply.github.com>
---
 README.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/README.md b/README.md
index 7679607e69..4889914691 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,6 @@
 python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html
 ```
 You can find a list of our developers and contributors on our [Contributors](/CONTRIBUTORS.md) page.
-page.
 ```note
 If you use CK, cite us as follows:

From efaf31061a00a9c17a888ddbf2e273aafe977d5e Mon Sep 17 00:00:00 2001
From: trixirt
Date: Thu, 14 Dec 2023 17:26:41 -0800
Subject: [PATCH 22/75] cmake: Add CK_PARALLEL_LINK_JOBS and CK_PARALLEL_COMPILE_JOBS options (#1063)

Copied from the llvm-project LLVM_PARALLEL_*_JOBS options.
Concurrent linking can break the build, as can having too many compile
jobs for the available memory. These options allow the user to fine-tune
the build to fit within their machine's memory constraints.

An example use on Linux is

COMPILE_JOBS=`cat /proc/cpuinfo | grep -m 1 'cpu cores' | awk '{ print $4 }'`
if [ ${COMPILE_JOBS}x = x ]; then
  COMPILE_JOBS=1
fi
BUILD_MEM=4
MEM_KB=0
MEM_KB=`cat /proc/meminfo | grep MemTotal | awk '{ print $2 }'`
MEM_MB=`eval "expr ${MEM_KB} / 1024"`
MEM_GB=`eval "expr ${MEM_MB} / 1024"`
COMPILE_JOBS_MEM=`eval "expr 1 + ${MEM_GB} / ${BUILD_MEM}"`
if [ "$COMPILE_JOBS_MEM" -lt "$COMPILE_JOBS" ]; then
  COMPILE_JOBS=$COMPILE_JOBS_MEM
fi
LINK_MEM=32
LINK_JOBS=`eval "expr 1 + ${MEM_GB} / ${LINK_MEM}"`
cmake -G Ninja -DCK_PARALLEL_LINK_JOBS=$LINK_JOBS -DCK_PARALLEL_COMPILE_JOBS=$COMPILE_JOBS

Signed-off-by: Tom Rix
---
 CMakeLists.txt | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index e780c15657..4e4b9d8d4b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -146,6 +146,33 @@ if(${hip_VERSION_FLAT} GREATER 500723302)
     add_compile_options(-fno-offload-uniform-block)
 endif()
+
+#
+# Separate linking jobs from compiling
+# Too many concurrent linking jobs can break the build
+# Copied from LLVM
+set(CK_PARALLEL_LINK_JOBS "" CACHE STRING
+  "Define the maximum number of concurrent link jobs (Ninja only).")
+if(CMAKE_GENERATOR MATCHES "Ninja")
+  if(CK_PARALLEL_LINK_JOBS)
+    set_property(GLOBAL APPEND PROPERTY JOB_POOLS link_job_pool=${CK_PARALLEL_LINK_JOBS})
+    set(CMAKE_JOB_POOL_LINK link_job_pool)
+  endif()
+elseif(CK_PARALLEL_LINK_JOBS)
+  message(WARNING "Job pooling is only available with Ninja generators.")
+endif()
+# Similar for compiling
+set(CK_PARALLEL_COMPILE_JOBS "" CACHE STRING
+  "Define the maximum number of concurrent compile jobs (Ninja only).")
+if(CMAKE_GENERATOR MATCHES "Ninja")
+  if(CK_PARALLEL_COMPILE_JOBS)
+    set_property(GLOBAL APPEND PROPERTY JOB_POOLS compile_job_pool=${CK_PARALLEL_COMPILE_JOBS})
+    set(CMAKE_JOB_POOL_COMPILE compile_job_pool)
+  endif()
+elseif(CK_PARALLEL_COMPILE_JOBS)
+  message(WARNING "Job pooling is only available with Ninja generators.")
+endif()
+
+
 option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
 option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons."
OFF) From 07092d68f0b13560caf3cbe762a9a799d13cdc0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 15 Dec 2023 12:45:08 +0100 Subject: [PATCH 23/75] Add tensor structure to wrapper (#1098) * Add tensor structure to wrapper * update changelog * Fix names * Comment fixes --- CHANGELOG.md | 2 +- docs/wrapper.rst | 39 ++- include/ck/wrapper/layout.hpp | 178 +++++++--- include/ck/wrapper/tensor.hpp | 314 ++++++++++++++++++ .../ck/wrapper/{ => utils}/layout_utils.hpp | 62 ++-- include/ck/wrapper/utils/tensor_utils.hpp | 290 ++++++++++++++++ test/wrapper/CMakeLists.txt | 2 + test/wrapper/test_layout.cpp | 16 +- test/wrapper/test_tensor.cpp | 205 ++++++++++++ 9 files changed, 1020 insertions(+), 88 deletions(-) create mode 100644 include/ck/wrapper/tensor.hpp rename include/ck/wrapper/{ => utils}/layout_utils.hpp (86%) create mode 100644 include/ck/wrapper/utils/tensor_utils.hpp create mode 100644 test/wrapper/test_tensor.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 3da22fc790..2891b8585b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ None - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) - Support for Batched Gemm DL (#732) -- Introduce wrapper sublibrary (limited functionality) (#1071) +- Introduce wrapper sublibrary (limited functionality). (#1071, #1098) ### Changes - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) diff --git a/docs/wrapper.rst b/docs/wrapper.rst index 64fb6a4031..a2f60b97ae 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -13,7 +13,7 @@ Description CK provides a lightweight wrapper for more complex operations implemented in the library. It allows indexing of nested layouts using a simple interface -(avoiding complex descriptor transformations). +(avoiding complex descriptor transformations) and memory access (using Tensor). Example: @@ -22,24 +22,31 @@ Example: const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::array data; + auto tensor = ck::wrapper::make_tensor(&data[0], layout); - std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; - for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) + for(ck::index_t w = 0; w < size(tensor); w++) { + tensor(w) = w; + } + + // slice() == slice(0, -1) (whole dimension) + auto tensor_slice = tensor(ck::wrapper::slice(1, 3), ck::make_tuple(ck::wrapper::slice(), ck::wrapper::slice())); + std::cout << "dims:2,(2,4) strides:2,(1,8)" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(tensor_slice); h++) { - for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + for(ck::index_t w = 0; w < ck::wrapper::size<1>(tensor_slice); w++) { - std::cout << layout(ck::make_tuple(h, w)) << " "; + std::cout << tensor_slice(h, w) << " "; } std::cout << std::endl; } Output:: - dims:4,(2,4) strides:2,(1,8) - 0 1 8 9 16 17 24 25 - 2 3 10 11 18 19 26 27 - 4 5 12 13 20 21 28 29 - 6 7 14 15 22 23 30 31 + dims:2,(2,4) strides:2,(1,8) + 1 5 9 13 17 21 25 29 + 2 6 10 14 18 22 26 30 ------------------------------------- Layout @@ -52,3 +59,15 @@ Layout helpers ------------------------------------- .. doxygenfile:: layout_utils.hpp + +------------------------------------- +Tensor +------------------------------------- + +.. 
doxygenstruct:: ck::wrapper::Tensor + +------------------------------------- +Tensor helpers +------------------------------------- + +.. doxygenfile:: tensor_utils.hpp diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index b337d88a1a..f20d985b49 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -3,7 +3,7 @@ #pragma once -#include "ck/wrapper/layout_utils.hpp" +#include "ck/wrapper/utils/layout_utils.hpp" namespace ck { namespace wrapper { @@ -25,6 +25,26 @@ struct Layout static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + // Generate default idxs tuple (idx with all merged nested shapes) + template + __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple&) + { + return generate_tuple( + [&](auto) { + if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) + { + // runtime layout + return index_t(0); + } + else + { + // compiletime layout + return I0; + } + }, + Number::Size()>{}); + } + // Generate packed (column-major) strides if not passed template __host__ __device__ constexpr static auto @@ -131,7 +151,7 @@ struct Layout template __host__ __device__ constexpr static auto MakeMerge1d(const Tuple& shape, - DescriptorToMerge& desc) + const DescriptorToMerge& desc) { // Reverse each element in tuple const auto merge_elems = TupleReverse(UnrollNestedTuple(shape)); @@ -144,7 +164,7 @@ struct Layout desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); } - // Merge nested shape dims. Merge nested shape dims when idx is also nested. + // Merge nested shape dims when corresponding index is also nested. // Input desc shape: 2, 2, 2, 2, 2, 2 // Example idx: 1, 1, 1, 1 // Example shape: 2, (2, 2), 2, (2, 2) @@ -187,14 +207,38 @@ struct Layout return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); } + template + __host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape, + const LayoutStrides& strides) + { + const auto unrolled_shape = UnrollNestedTuple(shape); + const auto unrolled_strides = UnrollNestedTuple(strides); + static_assert(unrolled_shape.Size() == unrolled_strides.Size(), + "Size of strides and shape are not consistent."); + return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); + } + + // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`. 
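+ // E.g. for shape (2, (2, 4)) the generated packed column-major strides are
+ // (1, (2, 4)): each stride is the product of the lengths of all
+ // faster-varying dimensions.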
+ using DeducedStrides = + std::conditional_t>, + remove_cvref_t, + Strides>; + using FlattenDescriptorType = + remove_cvref_t; + using Descriptor1dType = + remove_cvref_t; + using DefaultIdxsTupleType = remove_cvref_t; + template - __host__ __device__ constexpr auto TransformDesc(const Tuple& shape, - const Tuple& idx) const + __host__ __device__ constexpr static auto + TransformDesc(const Tuple& shape, + const Tuple& idx, + const FlattenDescriptorType& naive_descriptor) { if constexpr(Tuple::Size() == I1) { // 1d idx path - return MakeMerge1d(shape, descriptor_); + return MakeMerge1d(shape, naive_descriptor); } else { @@ -207,56 +251,53 @@ struct Layout // Unroll while IdxDims is nested const auto aligned_shape = AlignShapeToIdx(shape, idx); // Transform correct form of shape - return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), descriptor_); + return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor); } } - template - __host__ __device__ static auto MakeNaiveDescriptor(const LayoutShape& shape, - const LayoutStrides& strides) - { - const auto unrolled_shape = UnrollNestedTuple(shape); - const auto unrolled_strides = UnrollNestedTuple(strides); - static_assert(unrolled_shape.Size() == unrolled_strides.Size(), - "Size of strides and shape are not consistent."); - return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); - } + using MergedNestsDescriptorType = remove_cvref_t; public: - // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`. - using DeducedStrides = - std::conditional_t>, - remove_cvref_t, - Strides>; - using NaiveDescriptorType = - remove_cvref_t; + __host__ __device__ constexpr auto GetElementSpaceSize() const + { + return flatten_descriptor_.GetElementSpaceSize(); + } + __host__ __device__ Layout() = delete; /** * \brief Layout constructor. * * \param shape Shape for layout. * \param strides Strides for layout (optional if tensor is packed). - * \return Layout object. */ - __host__ __device__ Layout() = delete; - __host__ __device__ Layout(const Shape& shape, const Strides& strides) : descriptor_{} + __host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides) + : flatten_descriptor_{}, shape_(shape), strides_(strides) { // Construct if runtime mode - if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) + if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) { - shape_ = shape; - strides_ = strides; - descriptor_ = MakeNaiveDescriptor(shape_, strides_); + flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_); + descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_); + merged_nests_descriptor_ = + TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_); } } - __host__ __device__ Layout(const Shape& shape) : descriptor_{} + /** + * \brief Layout constructor (with default packed column-major strides). + * + * \param shape Shape for layout. 
+ */ + __host__ __device__ constexpr Layout(const Shape& shape) + : flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_)) { - if constexpr(!NaiveDescriptorType::IsKnownAtCompileTime()) + if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) { - shape_ = shape; - strides_ = GenerateColumnMajorPackedStrides(shape_); - descriptor_ = MakeNaiveDescriptor(shape_, strides_); + flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_); + descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_); + merged_nests_descriptor_ = + TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_); } } @@ -269,7 +310,9 @@ struct Layout template __host__ __device__ constexpr index_t operator()() const { - using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{})); + static_assert(FlattenDescriptorType::IsKnownAtCompileTime(), + "Compiletime operator used on runtime layout."); + using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{})); using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); return TransformedDesc{}.CalculateOffset(UnrolledIdx{}); } @@ -283,9 +326,22 @@ struct Layout template __host__ __device__ index_t operator()(const Tuple& Idx) const { - // Static to construct transformed_desc only once - static const auto transformed_desc = TransformDesc(shape_, Idx); - return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); + if constexpr(!IsNestedTuple(Tuple{}) && Tuple::Size() == 1) + { + // if 1d access + return descriptor_1d_.CalculateOffset(Idx); + } + else if constexpr(!IsNestedTuple(Tuple{}) && Tuple::Size() == Shape::Size()) + { + // if Shape::Size() access (merged nested shapes) + return merged_nests_descriptor_.CalculateOffset(UnrollNestedTuple(Idx)); + } + else + { + // Custom index, need to transform descriptor + const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_); + return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); + } } /** @@ -327,19 +383,51 @@ struct Layout * * \return Shape. */ - __host__ __device__ constexpr Shape GetShape() const { return shape_; } + __host__ __device__ constexpr const Shape& GetShape() const { return shape_; } /** * \brief Strides getter. * * \return Strides. */ - __host__ __device__ constexpr DeducedStrides GetStrides() const { return strides_; } + __host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; } + + /** + * \brief Get default lengths (tuple filled with Shape length elements). + * + * \return Default lengths. + */ + __host__ __device__ constexpr auto GetDefaultLengthsTuple() const + { + return generate_tuple([&](auto i) { return GetLength(); }, Number{}); + } + + /** + * \brief Get default start idx (tuple filled with 0s of the same size as Shape). + * + * \return Default start idx. + */ + __host__ __device__ constexpr auto GetDefaultStartIdxs() const + { + return GenerateDefaultIdxsTuple(shape_); + } + + /** + * \brief Get default descriptor (with the same size as Shape) + * + * \return Default descriptor. 
*/ __host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor() + { + return merged_nests_descriptor_; + } private: - NaiveDescriptorType descriptor_; - Shape shape_; - DeducedStrides strides_; + FlattenDescriptorType flatten_descriptor_; + Descriptor1dType descriptor_1d_; + MergedNestsDescriptorType merged_nests_descriptor_; + const Shape shape_; + const DeducedStrides strides_; }; } // namespace wrapper diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp new file mode 100644 index 0000000000..4ec6498fbc --- /dev/null +++ b/include/ck/wrapper/tensor.hpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "utils/tensor_utils.hpp" +#include "utils/layout_utils.hpp" + +namespace ck { +namespace wrapper { + +/** + * \brief Tensor wrapper that performs static and dynamic buffer logic. + * + * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR). + * \tparam ElementType Element data type. + * \tparam Shape Tensor shape (layout component). + * \tparam Strides Tensor strides (layout component). + * \tparam NumVectors Number of vectors (only for VGPR, SGPR). + * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR). + */ +template +struct Tensor +{ + private: + // Check if Tuple contains Slice object + template + constexpr static bool IsSlicing(T&&) + { + return is_detected::value; + } + template + constexpr static bool IsSlicing(Tuple&&) + { + return (IsSlicing(Ts{}) || ...); + } + + // Calculate first index of new tensor after slice + // It is needed to calculate offset for new tensor + template + constexpr auto GetStartIdxForSlicedTensor(const Tuple& idx) const + { + const auto start_idx_for_sliced_tensor = generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + // if tuple then recurse + return GetStartIdxForSlicedTensor(idx.At(num_i)); + } + else if constexpr(is_detected>>::value) + { + // if slice, return the beginning of the interval + return idx.At(num_i).from_; + } + else + { + // if one dim selected + return idx.At(num_i); + } + }, + Number::Size()>{}); + + return start_idx_for_sliced_tensor; + } + + // Calculate new tensor shape after slice + template + constexpr auto GetShapeFromSlicedTensor(const Tuple& idx, + const ShapeTmpType& shape) const + { + // Pack each value in tuple to remove empty tuples after generation + auto new_shape = generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + if constexpr(!IsSlicing(tuple_element_t>{})) + { + // if tuple does not have any slice then we can remove dimension + return Tuple<>{}; + } + else + { + // if tuple then recurse + return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i))); + } + } + else if constexpr(is_detected>>::value) + { + // calculate new dimension + const auto& dim = size(shape.At(num_i)); + const auto val = idx.At(num_i).range(dim); + return make_tuple(val); + } + else + { + // remove dimension for just value + return Tuple<>{}; + } + }, + Number::Size()>{}); + // Remove empty tuples (deleted elements) and return + return UnrollNestedTuple<0, 1>(new_shape); + } + + template + constexpr auto GetStridesFromSlicedTensor(const Tuple& idx, + const StridesTmpType& strides) const + { + // Pack each value in tuple to remove empty tuples after generation + auto new_strides = generate_tuple( + [&](auto i) { + constexpr auto num_i =
Number{}; + if constexpr(is_detected>>::value) + { + if constexpr(!IsSlicing(tuple_element_t>{})) + { + // if tuple does not have any slice then we can remove dimension + return Tuple<>{}; + } + else + { + // if tuple then recurse + return make_tuple( + GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i))); + } + } + else if constexpr(is_detected>>::value) + { + // Stride will be the same + return make_tuple(strides.At(num_i)); + } + else + { + // remove dimension for just value + return Tuple<>{}; + } + }, + Number::Size()>{}); + // Remove empty tuples (deleted elements) and return + return UnrollNestedTuple<0, 1>(new_strides); + } + + public: + using ElementSpaceSize = decltype(Layout{ + Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer + using TensorElementType = ElementType; // DataType + + static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace; + static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr || + BufferAddressSpace == MemoryTypeEnum ::Vgpr); + + __host__ __device__ Tensor() = delete; + __host__ __device__ Tensor(ElementType* pointer, const Layout& layout) + : layout_(layout), + buffer_(make_dynamic_buffer(pointer, layout.GetElementSpaceSize())) + { + } + + __host__ __device__ Tensor(const Layout& layout) : layout_(layout) + { + static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register."); + } + + __host__ __device__ constexpr const Layout& GetLayout() const + { + return layout_; + } + + // Getter for new sliced tensor + template {}), bool> = false> + __host__ __device__ auto operator[](const Tuple& idx) const + { + static_assert(IsDynamicBuffer, "Register slice is not supported"); + // Calculate offset based on first idx for new tensor + const index_t offset = layout_(GetStartIdxForSlicedTensor(idx)); + + auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape()); + if constexpr(is_same_v>) + { + auto new_layout = make_layout(new_shape); + return make_tensor(buffer_.p_data_ + offset, new_layout); + } + else + { + auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides()); + auto new_layout = make_layout(new_shape, new_strides); + return make_tensor(buffer_.p_data_ + offset, new_layout); + } + } + + template {}), bool> = false> + __host__ __device__ auto operator()(const Tuple& idx) const + { + return this->operator[](idx); + } + + template {}), bool> = false> + __host__ __device__ auto operator()(Idxs... idxs) const + { + return this->operator[](make_tuple(idxs...)); + } + + // Getter for the const value + template {}), bool> = false> + __host__ __device__ const ElementType& operator[](const Tuple& idx) const + { + if constexpr(IsDynamicBuffer) + { + const index_t offset = layout_(idx); + return buffer_[offset]; + } + else + { + if constexpr(is_same_v>) + { + constexpr index_t offset = + Layout{Shape{}}.template operator()>(); + return buffer_[Number{}]; + } + else + { + constexpr index_t offset = + Layout{Shape{}, Strides{}}.template operator()>(); + return buffer_[Number{}]; + } + } + } + + template {}), bool> = false> + __host__ __device__ const ElementType& operator()(const Tuple& idx) const + { + return this->operator[](idx); + } + + template {}), bool> = false> + __host__ __device__ const ElementType& operator()(Idxs...
idxs) const + { + return this->operator[](make_tuple(idxs...)); + } + + // Getter for the value reference + template {}), bool> = false> + __host__ __device__ ElementType& operator[](const Tuple& idx) + { + if constexpr(IsDynamicBuffer) + { + const index_t offset = layout_(idx); + return buffer_(offset); + } + else + { + if constexpr(is_same_v>) + { + constexpr index_t offset = + Layout{Shape{}}.template operator()>(); + return buffer_(Number{}); + } + else + { + constexpr index_t offset = + Layout{Shape{}, Strides{}}.template operator()>(); + return buffer_(Number{}); + } + } + } + + template {}), bool> = false> + __host__ __device__ ElementType& operator()(const Tuple& idx) + { + return this->operator[](idx); + } + + template {}), bool> = false> + __host__ __device__ ElementType& operator()(Idxs... idxs) + { + return this->operator[](make_tuple(idxs...)); + } + + __host__ __device__ constexpr auto GetDefaultDescriptor() + { + return layout_.GetDefaultDescriptor(); + } + + private: + using DynamicBufferType = DynamicBuffer; + using StaticBufferType = + StaticBufferTupleOfVector; + // If register use static buffer, else use dynamic buffer + using Buffer = std::conditional_t; + + const Layout layout_; + Buffer buffer_; +}; + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp similarity index 86% rename from include/ck/wrapper/layout_utils.hpp rename to include/ck/wrapper/utils/layout_utils.hpp index fac8f33854..5df9dd7dea 100644 --- a/include/ck/wrapper/layout_utils.hpp +++ b/include/ck/wrapper/utils/layout_utils.hpp @@ -22,7 +22,7 @@ namespace wrapper { // Disable from doxygen docs generation /// @cond // forward declaration -template > +template struct Layout; template @@ -52,13 +52,23 @@ __host__ __device__ constexpr Layout make_layout(const Shape& sh * \return Constructed layout. */ template -__host__ __device__ constexpr Layout make_layout(const Shape& shape) +__host__ __device__ constexpr Layout> make_layout(const Shape& shape) { - return Layout(shape); + return Layout>(shape); } // Layout helpers // get +// Get dim (could be returned from get with empty Idxs) +/** + * \private + */ +template +__host__ __device__ T constexpr get(const T& dim) +{ + return dim; +} + /** * \brief Get element from tuple (Shape/Strides/Idxs). * @@ -82,7 +92,8 @@ __host__ __device__ constexpr auto get(const Tuple& tuple) template __host__ __device__ constexpr auto get(const Layout& layout) { - const auto new_shape = get(layout.GetShape()); + const auto& shape = layout.GetShape(); + const auto& new_shape = get(shape); static_assert(is_detected::value, "Shape of sub layout must be tuple"); if constexpr(is_same_v>) @@ -92,7 +103,8 @@ __host__ __device__ constexpr auto get(const Layout& layout) } else { - const auto new_strides = get(layout.GetStrides()); + const auto& strides = layout.GetStrides(); + const auto& new_strides = get(strides); static_assert(is_detected::value, "Strides of sub layout must be tuple"); return make_layout(new_shape, new_strides); @@ -113,11 +125,21 @@ __host__ __device__ constexpr auto get(const T& elem) } // size +// Get dim size (could be returned from get function) +/** + * \private + */ +template +__host__ __device__ T constexpr size(const T& dim) +{ + return dim; +} + /** * \brief Length get (product if tuple). * * \tparam idx Index to lookup. - * \param layout Layout to get Shape. + * \param layout Layout to get Shape of. * \return Requsted length. 
*/ template @@ -140,16 +162,6 @@ __host__ __device__ constexpr index_t size(const Tuple& shape) unrolled_shape); } -// Get dim size (could be returned from get function) -/** - * \private - */ -template -__host__ __device__ T constexpr size(const T& dim) -{ - return dim; -} - /** * \brief Layout size (product of dims). * @@ -178,14 +190,15 @@ __host__ __device__ constexpr index_t size(const Tuple& tuple) /** * \brief Hierarchical size. * - * \tparam Idxs Indexes to lookup. + * \tparam Idx First index to lookup (to avoid empty Idxs). + * \tparam Idxs Next indexes to lookup. * \param elem Element to lookup. * \return Requsted element. */ -template +template __host__ __device__ constexpr auto size(const T& elem) { - return size(get(elem)); + return size(get(elem)); } // rank @@ -251,7 +264,8 @@ __host__ __device__ constexpr auto rank(const T& elem) template __host__ __device__ constexpr auto depth(const Layout& layout) { - return TupleDepth(layout.GetShape()); + const auto& shape = layout.GetShape(); + return TupleDepth(shape); } /** @@ -296,11 +310,11 @@ __host__ __device__ constexpr auto depth(const T& elem) /** * \brief Get Layout strides. * - * \param layout Layout to get strides. + * \param layout Layout to get strides from. * \return Requsted strides. */ template -__host__ __device__ constexpr auto stride(const Layout& layout) +__host__ __device__ constexpr const auto& stride(const Layout& layout) { return layout.GetStrides(); } @@ -308,11 +322,11 @@ __host__ __device__ constexpr auto stride(const Layout& layout) /** * \brief Get Layout shape. * - * \param layout Layout to get shape. + * \param layout Layout to get shape from. * \return Requsted shape. */ template -__host__ __device__ constexpr auto shape(const Layout& layout) +__host__ __device__ constexpr const auto& shape(const Layout& layout) { return layout.GetShape(); } diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp new file mode 100644 index 0000000000..5f0dc3e500 --- /dev/null +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -0,0 +1,290 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/dynamic_buffer.hpp" +#include "ck/utility/amd_address_space.hpp" + +namespace ck { +namespace wrapper { + +/** + * \brief Memory type, allowed members: + * - Generic, + * - Global, + * - LDS, + * - SGPR, + * - VGPR, + */ +using MemoryTypeEnum = AddressSpaceEnum; + +// Disable from doxygen docs generation +/// @cond +// forward declarations +template +struct Layout; +template + +struct Tensor; + +template +struct Slice +{ + __host__ __device__ constexpr Slice() : from_(), to_() {} + __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {} + + template + __host__ __device__ constexpr auto range(const T& dim) const + { + if constexpr(is_same_v || is_same_v || + is_same_v) + { + assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range"); + if(to_ < 0) + { + return dim - from_ + to_ + 1; + } + else + { + // workaround if one end of the interval is index_t and the second one is Number + return static_cast(to_) - static_cast(from_); + } + } + else + { + static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_), + "Invalid range"); + if constexpr(to_ < 0) + { + return dim - from_ + to_ + Number<1>{}; + } + else + { + return to_ - from_; + } + } + } + + __host__ __device__ static constexpr bool IsSlice() { return true; } + + const FromType from_; + const ToType to_; +}; + +template +using is_slice = decltype(std::declval().IsSlice()); + +template +using is_tuple = decltype(std::declval().IsTuple()); +/// @endcond + +/** + * \brief Make tensor function. + * + * \tparam MemoryType Type of memory. + * \param pointer Pointer to the memory. + * \param layout Tensor layout. + * \return Constructed tensor. + */ +template +constexpr auto make_tensor(ElementType* pointer, const Layout& layout) +{ + return Tensor( + pointer, layout); +} + +/** + * \brief Make SGPR or VGPR tensor function. + * + * \tparam MemoryType Type of memory. + * \tparam NumVectors Number of vectors. + * \tparam ScalarPerVector Scalars per vector. + * \tparam ElementType Memory data type. + * \param layout Tensor layout. + * \return Constructed tensor. + */ +template +constexpr auto make_register_tensor(const Layout& layout) +{ + static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported"); + return Tensor(layout); +} + +/** + * \brief Get Tensor Layout. + * + * \param tensor Tensor to get layout of. + * \return Requested layout. + */ +template +__host__ __device__ constexpr const auto& +layout(const Tensor& + tensor) +{ + return tensor.GetLayout(); +} + +/** + * \brief Product of tensor shape dims. + * + * \tparam Idxs Indexes to access specific shape dim (optional). + * \param tensor Tensor to get Shape of. + * \return Requested size. + */ +template +__host__ __device__ constexpr index_t +size(const Tensor& + tensor) +{ + return size(tensor.GetLayout()); +} + +/** + * \brief Rank of Shape tuple. + * + * \tparam Idxs Indexes to access specific shape dim (optional). + * \param tensor Tensor to get rank of. + * \return Requested rank. + */ +template +__host__ __device__ constexpr index_t +rank(const Tensor& + tensor) +{ + return rank(tensor.GetLayout()); +} + +/** + * \brief Depth of Shape tuple. + * + * \tparam Idxs Indexes to access specific shape dim (optional). + * \param tensor Tensor to get depth of. + * \return Requested depth. 
*/ +template +__host__ __device__ constexpr index_t +depth(const Tensor& + tensor) +{ + return depth(tensor.GetLayout()); +} + +/** + * \brief Get Tensor strides. + * + * \param tensor Tensor to get strides from. + * \return Requested strides. + */ +template +__host__ __device__ constexpr const auto& +stride(const Tensor& + tensor) +{ + return stride(tensor.GetLayout()); +} + +/** + * \brief Get Tensor shape. + * + * \param tensor Tensor to get shape from. + * \return Requested shape. + */ +template +__host__ __device__ constexpr const auto& +shape(const Tensor& + tensor) +{ + return shape(tensor.GetLayout()); +} + +/** + * \brief Get dim slice. + * + * \param from Beginning of the interval. + * \param to End of the interval. (can also be negative to index from the end) + * \return Requested slice. Can be used to create a sliced tensor from another tensor. + */ +template +constexpr auto slice(const FromType from, const ToType to) +{ + return Slice(from, to); +} + +/** + * \brief Get dim slice. (Assumes that from is equal to 0) + * + * \param to End of the interval. (can also be negative to index from the end) + * \return Requested slice. Can be used to create a sliced tensor from another tensor. + */ +template +constexpr auto slice(const ToType to) +{ + if constexpr(is_same_v) + { + return Slice(0, to); + } + else + { + return Slice, ToType>(Number<0>{}, to); + } +} + +/** + * \brief Get whole dim slice (from = 0, to = -1). + * + * \return Requested slice. Can be used to create a sliced tensor from another tensor. + */ +constexpr auto slice() { return Slice, Number<-1>>(Number<0>{}, Number<-1>{}); } + +} // namespace wrapper +} // namespace ck diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index e25ef176dd..6b25c08a8a 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -1,2 +1,4 @@ add_gtest_executable(test_layout test_layout.cpp) target_link_libraries(test_layout PRIVATE utility) +add_gtest_executable(test_tensor test_tensor.cpp) +target_link_libraries(test_tensor PRIVATE utility) diff --git a/test/wrapper/test_layout.cpp b/test/wrapper/test_layout.cpp index 7d09696fbb..14a8b96462 100644 --- a/test/wrapper/test_layout.cpp +++ b/test/wrapper/test_layout.cpp @@ -433,17 +433,17 @@ TEST(TestLayoutHelpers, ShapeAndStrides) ck::wrapper::make_layout(shape_compiletime, strides_compiletime); constexpr bool check_compiletime_shape = - std::is_same_v::type, - decltype(shape(layout_compiletime))>; + std::is_same_v>; constexpr bool check_compiletime_strides = - std::is_same_v::type, - decltype(stride(layout_compiletime))>; + std::is_same_v>; constexpr bool check_runtime_shape = - std::is_same_v::type, - decltype(shape(layout_runtime))>; + std::is_same_v>; constexpr bool check_runtime_strides = - std::is_same_v::type, - decltype(stride(layout_runtime))>; + std::is_same_v>; EXPECT_TRUE(check_compiletime_shape); EXPECT_TRUE(check_compiletime_strides); EXPECT_TRUE(check_runtime_shape); diff --git a/test/wrapper/test_tensor.cpp b/test/wrapper/test_tensor.cpp new file mode 100644 index 0000000000..92f8e2e1bd --- /dev/null +++ b/test/wrapper/test_tensor.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "ck/library/utility/device_memory.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +#include "ck/utility/common_header.hpp" + +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" + +// Compare data in tensor with offset from layout. +// Data and offset should match if physical memory has been initialized with +// sequentially increasing values from 0. +template +__host__ __device__ bool TestTensorCheck3d(TensorType& tensor) +{ + const auto& layout = ck::wrapper::layout(tensor); + for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++) + { + for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + const auto idx = ck::make_tuple(ck::make_tuple(d, h), w); + if(tensor(idx) != layout(idx)) + { + return false; + } + } + } + } + return true; +} + +template +__host__ __device__ bool TestTensorCheck1d(TensorType& tensor, ck::index_t start_offset = 0) +{ + const auto& layout = ck::wrapper::layout(tensor); + for(ck::index_t w = 0; w < ck::wrapper::size<0>(layout); w++) + { + if(tensor(w) - start_offset != layout(ck::make_tuple(w))) + { + return false; + } + } + return true; +} + +template +__host__ __device__ bool StaticTestTensorCheck1d(TensorType& tensor) +{ + const auto& layout = ck::wrapper::layout(tensor); + bool success = true; + ck::static_for<0, nelems, 1>{}([&](auto w) { + if(tensor(ck::Number{}) != layout(ck::make_tuple(w.value))) + { + success = false; + } + }); + return success; +} + +template +__host__ __device__ void InitTensor(TensorType& tensor) +{ + for(ck::index_t i = 0; i < ck::wrapper::size(ck::wrapper::layout(tensor)); i++) + { + tensor(i) = i; + } +} + +template +__host__ __device__ void StaticInitTensor(TensorType& tensor) +{ + + ck::static_for<0, nelems, 1>{}([&](auto i) { tensor(ck::Number{}) = i.value; }); +} + +// Tests +TEST(TestTensor, ReadWriteHostMemory) +{ + constexpr ck::index_t nelems = 8; + + std::array data; + const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2)); + auto tensor = ck::wrapper::make_tensor(&data[0], layout); + InitTensor(tensor); + + EXPECT_TRUE(TestTensorCheck1d(tensor)); + EXPECT_TRUE(TestTensorCheck3d(tensor)); +} + +__global__ void TestTensorReadWriteDevice(void* data, void* success) +{ + constexpr ck::index_t nelems = 8; + constexpr ck::index_t scalar_per_vector = 1; + __shared__ ck::index_t p_shared[nelems]; + + ck::index_t* casted_data_ptr = static_cast(data); + bool* casted_success_ptr = static_cast(success); + + const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2)); + constexpr auto register_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{})); + + auto tensor_global = + ck::wrapper::make_tensor(casted_data_ptr, layout); + auto tensor_lds = ck::wrapper::make_tensor(p_shared, layout); + auto tensor_vgpr = ck::wrapper::make_register_tensor(register_layout); + auto tensor_sgpr = ck::wrapper::make_register_tensor(register_layout); + + InitTensor(tensor_global); + InitTensor(tensor_lds); + StaticInitTensor(tensor_vgpr); + StaticInitTensor(tensor_sgpr); + + *casted_success_ptr &= TestTensorCheck1d(tensor_global); + *casted_success_ptr &= TestTensorCheck3d(tensor_global); + + *casted_success_ptr &= TestTensorCheck1d(tensor_lds); + *casted_success_ptr &= TestTensorCheck3d(tensor_lds); + + *casted_success_ptr &= StaticTestTensorCheck1d(tensor_vgpr); + + 
*casted_success_ptr &= StaticTestTensorCheck1d(tensor_sgpr); +} + +TEST(TestTensor, ReadWriteGlobalLdsRegistersMemory) +{ + constexpr ck::index_t nelems = 8; + std::array host_data; + + DeviceMem data_buf(nelems * sizeof(ck::index_t)); + data_buf.ToDevice(&host_data[0]); + DeviceMem success_buf(sizeof(bool)); + + launch_and_time_kernel(StreamConfig{}, + TestTensorReadWriteDevice, + dim3(1), + dim3(1), + nelems * sizeof(ck::index_t), + data_buf.GetDeviceBuffer(), + success_buf.GetDeviceBuffer()); + + bool success; + success_buf.FromDevice(&success); + EXPECT_TRUE(success); +} + +TEST(TestTensor, Slicing) +{ + constexpr ck::index_t nelems = 8; + + std::array data; + const auto shape = ck::make_tuple(ck::make_tuple(2, 2), 2); + const auto strides = ck::make_tuple(ck::make_tuple(1, 2), 4); + const auto layout = ck::wrapper::make_layout(shape, strides); + auto tensor = ck::wrapper::make_tensor(&data[0], layout); + InitTensor(tensor); + + auto tensor2x2x2 = + tensor(ck::make_tuple(ck::wrapper::slice(2), ck::wrapper::slice(2)), ck::wrapper::slice(2)); + EXPECT_EQ(ck::wrapper::rank(tensor2x2x2), 2); + EXPECT_EQ(ck::wrapper::depth(tensor2x2x2), 2); + EXPECT_EQ(ck::wrapper::size(tensor2x2x2), 8); + EXPECT_TRUE(TestTensorCheck1d(tensor2x2x2)); + + auto tensor2x2 = tensor(ck::make_tuple(1, ck::wrapper::slice(2)), ck::wrapper::slice(2)); + EXPECT_EQ(ck::wrapper::rank(tensor2x2), 2); + EXPECT_EQ(ck::wrapper::depth(tensor2x2), 2); + EXPECT_EQ(ck::wrapper::size(tensor2x2), 4); + EXPECT_TRUE(TestTensorCheck1d(tensor2x2, layout(ck::make_tuple(ck::make_tuple(1, 0), 0)))); + + auto tensor1x1 = tensor(ck::make_tuple(1, ck::wrapper::slice(1, 2)), ck::wrapper::slice(1, 2)); + EXPECT_EQ(rank(tensor1x1), 2); + EXPECT_EQ(depth(tensor1x1), 2); + EXPECT_EQ(size(tensor1x1), 1); + EXPECT_TRUE(TestTensorCheck1d(tensor1x1, layout(ck::make_tuple(ck::make_tuple(1, 1), 1)))); + + auto tensor2 = tensor(ck::make_tuple(1, 1), ck::wrapper::slice(0, 2)); + EXPECT_EQ(ck::wrapper::rank(tensor2), 1); + EXPECT_EQ(ck::wrapper::depth(tensor2), 1); + EXPECT_EQ(ck::wrapper::size(tensor2), 2); + EXPECT_TRUE(TestTensorCheck1d(tensor2, layout(ck::make_tuple(ck::make_tuple(1, 1), 0)))); + + // negative indexing + auto tensor1x2 = tensor(ck::make_tuple(1, ck::wrapper::slice(0, -2)), ck::wrapper::slice()); + EXPECT_EQ(rank(tensor1x2), 2); + EXPECT_EQ(depth(tensor1x2), 2); + EXPECT_EQ(size(tensor1x2), 2); + EXPECT_TRUE(TestTensorCheck1d(tensor1x2, layout(ck::make_tuple(ck::make_tuple(1, 0), 0)))); +} From 3246d1f693035929562240d8c73611345692bbbc Mon Sep 17 00:00:00 2001 From: abhimeda <138710508+abhimeda@users.noreply.github.com> Date: Fri, 15 Dec 2023 12:41:35 -0500 Subject: [PATCH 24/75] Adding Issue Template (#1094) * Add files via upload * fixed extra space typo * add mi300 GPU architectures and rocm versions 5.6.1 and 6.0.0 --------- Co-authored-by: illsilin Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/config.yml | 1 + .github/ISSUE_TEMPLATE/issue_report.yml | 221 ++++++++++++++++++++++++ 2 files changed, 222 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/ISSUE_TEMPLATE/issue_report.yml diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..0086358db1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: true diff --git a/.github/ISSUE_TEMPLATE/issue_report.yml b/.github/ISSUE_TEMPLATE/issue_report.yml new file mode 100644 index 
0000000000..ef6e6faa1b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue_report.yml @@ -0,0 +1,221 @@ +name: Issue Report +description: File a report for ROCm related issues on Linux and Windows. For issues pertaining to documentation or non-bug related, please open a blank issue located below. +title: "[Issue]: " + +body: +- type: markdown + attributes: + value: | + Thank you for taking the time to fill out this report! + + You can acquire your OS, CPU, GPU (for filling out this report) with the following commands: + + Linux: + echo "OS:" && cat /etc/os-release | grep -E "^(NAME=|VERSION=)"; + echo "CPU: " && cat /proc/cpuinfo | grep "model name" | sort --unique; + echo "GPU:" && /opt/rocm/bin/rocminfo | grep -E "^\s*(Name|Marketing Name)"; + + Windows: + (Get-WmiObject Win32_OperatingSystem).Version + (Get-WmiObject win32_Processor).Name + (Get-WmiObject win32_VideoController).Name +- type: textarea + attributes: + label: Problem Description + description: Describe the issue you encountered. + validations: + required: true +- type: input + attributes: + label: Operating System + description: What is the name and version number of the OS? + placeholder: "e.g. Ubuntu 22.04.3 LTS (Jammy Jellyfish)" + validations: + required: true +- type: input + attributes: + label: CPU + description: What CPU did you encounter the issue on? + placeholder: "e.g. AMD Ryzen 9 5900HX with Radeon Graphics" + validations: + required: true +- type: dropdown + attributes: + label: GPU + description: What GPU(s) did you encounter the issue on (you can select multiple GPUs from the list) + multiple: true + options: + - AMD Instinct MI300X + - AMD Instinct MI300A + - AMD Instinct MI300 + - AMD Instinct MI250X + - AMD Instinct MI250 + - AMD Instinct MI210 + - AMD Instinct MI100 + - AMD Instinct MI50 + - AMD Instinct MI25 + - AMD Radeon Pro V620 + - AMD Radeon Pro VII + - AMD Radeon RX 7900 XTX + - AMD Radeon VII + - AMD Radeon Pro W7900 + - AMD Radeon Pro W7800 + - AMD Radeon Pro W6800 + - AMD Radeon Pro W6600 + - AMD Radeon Pro W5500 + - AMD Radeon RX 7900 XT + - AMD Radeon RX 7600 + - AMD Radeon RX 6950 XT + - AMD Radeon RX 6900 XT + - AMD Radeon RX 6800 XT + - AMD Radeon RX 6800 + - AMD Radeon RX 6750 + - AMD Radeon RX 6700 XT + - AMD Radeon RX 6700 + - AMD Radeon RX 6650 XT + - AMD Radeon RX 6600 XT + - AMD Radeon RX 6600 + - Other + validations: + required: true +- type: input + attributes: + label: Other + description: If you selected Other, please specify +- type: dropdown + attributes: + label: ROCm Version + description: What version(s) of ROCm did you encounter the issue on? + multiple: true + options: + - ROCm 6.0.0 + - ROCm 5.7.1 + - ROCm 5.7.0 + - ROCm 5.6.1 + - ROCm 5.6.0 + - ROCm 5.5.1 + - ROCm 5.5.0 + validations: + required: true +- type: dropdown + attributes: + label: ROCm Component + description: (Optional) If this issue relates to a specific ROCm component, it can be mentioned here. 
+ multiple: true + options: + - Other + - AMD Common Language Runtime + - AMD MIGraphX + - AMD System Management Interface + - amdgpu KCL/autoconf + - amdgpu Kernel-mode GPU Driver + - amdgpu-install + - AOMP + - AOMP Extras + - AqlProfile + - build-infra + - chelsio + - clang-ocl + - Composable Kernel + - dkms + - docker / ROCm-docker + - flang + - gpuburn + - half + - HIP + - HIP Examples + - hipBLAS + - hipBLASLt + - HIPCC + - hipCUB + - hip-examples-private + - hipFFT + - hipfort + - HIPIFY + - hipRAND + - hipSOLVER + - hipSPARSE + - hipSPARSELt + - hipTensor + - hip-tests + - HSA Runtime + - infrastructure + - jenkins-utils + - libdrm + - Linux BPI packaging framework + - llvm-project + - Mesa + - meta + - MIOpen + - MIVisionX + - ml-framework-ci + - MLSEQA_TestRepo + - OpenCL API C++ Bindings + - OpenCL API Headers + - OpenCL Conformance Test Suite + - OpenCL ICD Loader + - perftest-p2p + - prototype + - RCCL + - rccl-rdma-sharp-plugins + - rocALUTION + - rocBLAS + - ROCdbgapi + - ROCdebug-agent + - rocFFT + - ROCgdb + - ROCK + - ROCm Documentation/Website + - ROCm Data Center Tool + - ROCm Examples + - ROCm for Windows + - ROCm Performance Primitives + - ROCm System Management Interface Library + - ROCm Thrust + - ROCm Validation Suite + - rocm_bandwidth_test + - rocm-cmake + - rocm-core + - rocm-docs-core + - rocminfo + - rocMLIR + - rocmtools + - rocPRIM + - rocprofiler + - rocRAND + - ROCR-Runtime + - rocSOLVER + - rocSPARSE + - roctracer + - ROCT-Thunk-Interface + - rocWMMA + - Tensile + - umr + - ibv_rc_pingpong-amd + - mellanox + - mpitest + - Pytorch + - Tensorflow + - APEX + - torchvision + - Magma +- type: textarea + attributes: + label: Steps to Reproduce + description: (Optional) Detailed steps to reproduce the issue. + validations: + required: false + +- type: textarea + attributes: + label: (Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support + description: The output of rocminfo --support could help to better address the problem. + validations: + required: false + +- type: textarea + attributes: + label: Additional Information + description: (Optional) Any additional information that is relevant, e.g. relevant environment variables, dockerfiles, log files, dmesg output (on Linux), etc. + validations: + required: false From dcedf3632f0e066c1712add65cb440622416363e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Sat, 16 Dec 2023 09:17:40 -0800 Subject: [PATCH 25/75] Upgrade the default compiler to ROCm6.0 release. 
(#1103) * upgrade to rocm6.0 compiler * move rocm6.0 from private to public repo * switch to testing hipTensor mainline in CI --- Dockerfile | 6 +++--- Jenkinsfile | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/Dockerfile b/Dockerfile index 87b4eb8e2b..b9339ec5d4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM ubuntu:20.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=5.7 +ARG ROCMVERSION=6.0 ARG compiler_version="" ARG compiler_commit="" @@ -16,8 +16,8 @@ RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN wget https://repo.radeon.com/amdgpu-install/5.7/ubuntu/focal/amdgpu-install_5.7.50700-1_all.deb --no-check-certificate -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_5.7.50700-1_all.deb +RUN wget https://repo.radeon.com/amdgpu-install/6.0/ubuntu/focal/amdgpu-install_6.0.60000-1_all.deb --no-check-certificate +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.0.60000-1_all.deb RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ diff --git a/Jenkinsfile b/Jenkinsfile index 8f661e4780..2bb48b85ce 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -33,7 +33,7 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.ROCMVERSION != "6.0"){ + if (params.ROCMVERSION != "6.1"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -655,8 +655,8 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION= - 0 21 * * * % ROCMVERSION=5.7;COMPILER_VERSION=;COMPILER_COMMIT= +CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.0;COMPILER_VERSION= + 0 21 * * * % ROCMVERSION=6.0;COMPILER_VERSION=;COMPILER_COMMIT= 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" @@ -675,8 +675,8 @@ pipeline { description: "Force building docker image (default: false), set to true if docker image needs to be updated.") string( name: 'ROCMVERSION', - defaultValue: '5.7', - description: 'Specify which ROCM version to use: 5.7 (default).') + defaultValue: '6.0', + description: 'Specify which ROCM version to use: 6.0 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', @@ -703,8 +703,8 @@ pipeline { description: "Use the CK build to verify hipTensor build and tests (default: ON)") string( name: 'hipTensor_branch', - defaultValue: 'develop', - description: 'Specify which branch of hipTensor to use (default: develop)') + defaultValue: 'mainline', + description: 'Specify which branch of hipTensor to use (default: mainline)') booleanParam( name: "USE_SCCACHE", defaultValue: true, From ad0a8e4cd27e8b66781d18ca6a7b1190e0611597 Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Mon, 18 Dec 2023 11:09:10 +0100 Subject: [PATCH 26/75] Optimize fp16 direct load GEMM instances (#1086) This PR optimizes fp16 instances of direct load GEMM kernel introduced in #999 and #1052. Measured the performance of new instances on CDNA2 GPU and compared it against the performance of the best non-direct-load GEMM instances. Used 76 different GEMM problems. On average, this change improves the performance of the tested problems by 47%. For cases known as latency-bound, the speedup is around 126%. 
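For orientation when reading the instance table below: each entry is one compile-time tuning configuration of the direct-load kernel. The sketch that follows annotates the first new instance, with parameter meanings read off the column-header comment kept in the file; treat the grouping as approximate:

    DeviceGemm_Xdl_CShuffle_LdsDirectLoad<
        Row, Col, Row,                         // A is m-k row-major, B is n-k column-major, C is m-n row-major
        F16, F16, F16, F32, F32,               // A/B/C element types, accumulation type, C-shuffle data type
        PassThrough, PassThrough, PassThrough, // elementwise operations applied to A, B and C
        GemmDefault,                           // GEMM specialization (no M/N padding)
        2,                                     // number of prefetch stages
        256, 64, 64, 64,                       // block size, then M/N/K tile sizes per block
        16, 16,                                // AK1/BK1: contiguous K elements fetched per load
        32, 32, 1, 1,                          // XDL (MFMA) tile M/N, XDL tiles per wave in M/N
        S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0,       // A transfer: thread cluster K0 x M x K1, source access order,
                                               // source vector dim, scalars per vector, extra-M LDS padding
        S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0,       // B transfer: same fields
        1, 1, S<1, 8, 1, 8>, 4>                // C-shuffle M/N XDL-per-wave per shuffle, C block transfer
                                               // cluster lengths, scalars per vector on the C store

Relative to the removed configurations, most new entries widen the K1 load vectors (8 -> 16 or 32 elements) and add smaller, skinnier tiles (for example 16x32 per block), which matches the latency-bound cases called out above.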
--- ...ect_load_f16_f16_f16_mk_nk_mn_instance.cpp | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp index bb40237bf9..4c12e515e8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp @@ -34,22 +34,19 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = // ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 
8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 16, 16, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 16, 16, 16, 16, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 
2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; From a69aa2a11a83dc8e7b39be20aa50ec73d539dc28 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 19 Dec 2023 04:23:11 +0800 Subject: [PATCH 27/75] layernorm and groupnorm backward data (#1083) * rename folder * Add type string * Remove typo * Add deviceOp to backward x * Add comment to describe the behavior of backward normalization * Add kernel function, prepare to implement * implement generic kernel * Check vector size * Add sweep once pipeline for small reduce size * Fix bug of KRaw_ error * Fix bug of dx stride * sanity check for mean and rstd * backward x for groupnorm * Add bwd x instance * add layernorm 2d bwd gamma beta instances * Change save mean var type from f32 to f16 in f16 mode * Change the example to f16 * Add groupnorm bwd gamma beta instance * Add groupnorm bwd x instance * Fix naming * Add layernorm bwd x ckprofiler * Add groupnorm bwd x profiler * clang format * Rename bwd x to bwd data * Fix bug of verification in profiler * Add test of layernorm and groupnorm bwd data * Add missing cmake * Add layernorm2d bwd data * rename fwd example * Add groupnorm client example * Fix typo: replace Invarient with Invariant * Add checking before running the best instance
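For reference, the gradients these kernels compute follow the standard normalization backward formulas (a sketch in generic notation, not text from the kernel sources). With \mu and r = 1/\sqrt{\sigma^2 + \varepsilon} the saved mean and inverse standard deviation, \hat{x}_i = (x_i - \mu)\,r the normalized input, and N the number of elements in the reduced dimension(s):

    g_i       = \gamma_i \, dy_i
    dx_i      = r \left( g_i - \frac{1}{N} \sum_j g_j - \hat{x}_i \cdot \frac{1}{N} \sum_j g_j \hat{x}_j \right)
    d\gamma_i = \sum_{\mathrm{batch}} dy_i \, \hat{x}_i
    d\beta_i  = \sum_{\mathrm{batch}} dy_i

The bwd-data kernels compute dx from dy, x, gamma and the saved mean/inv-std tensors; the bwd-gamma-beta kernels perform the sums over the invariant (batch) dimensions. The sweep-once pipeline for small reduce sizes presumably fuses the two row reductions and the dx update into a single pass when the whole reduce dimension fits in one tile.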
--- client_example/01_gemm/gemm.cpp | 1 + .../gemm_add_add_fastgelu.cpp | 1 + .../gemm_add_fastgelu.cpp | 1 + .../gemm_fastgelu.cpp | 1 + .../gemm_add_relu_add_layernorm_welford.cpp | 1 + client_example/05_layernorm/CMakeLists.txt | 3 + .../05_layernorm/layernorm2d_bwd_data.cpp | 170 ++++++ .../05_layernorm/layernorm2d_fwd.cpp | 3 +- .../05_layernorm/layernorm4d_fwd.cpp | 3 +- client_example/06_softmax/softmax4d.cpp | 1 + .../elementwise_layernorm2d.cpp | 1 + .../gemm_add_multiply.cpp | 1 + client_example/18_groupnorm/CMakeLists.txt | 7 +- .../18_groupnorm/groupnorm_bwd_data.cpp | 182 ++++++ ...norm_swish.cpp => groupnorm_swish_fwd.cpp} | 0 .../20_splitk_gemm/splitK_gemm_fp16_f8.cpp | 1 + .../elementwise_transpose_3d.cpp | 1 + example/53_layernorm2d_bwd/CMakeLists.txt | 1 + .../layernorm2d_bwd_fp32.cpp} | 97 ++- example/53_layernorm_bwd/CMakeLists.txt | 1 - example/54_groupnorm_bwd/CMakeLists.txt | 2 +- ...rm_bwd_fp16.cpp => groupnorm_bwd_fp32.cpp} | 101 +++- .../device/device_normalization_bwd_data.hpp | 59 ++ .../device_normalization_bwd_data_impl.hpp | 465 +++++++++++++++ ...vice_normalization_bwd_gamma_beta_impl.hpp | 32 +- .../impl/device_normalization_fwd_impl.hpp | 6 +- .../device_normalization_fwd_splitk_impl.hpp | 4 +- .../gridwise_normalization_bwd_data.hpp | 554 ++++++++++++++++++ .../gridwise_normalization_bwd_gamma_beta.hpp | 11 +- .../cpu/reference_groupnorm_bwd.hpp | 25 + .../cpu/reference_layernorm_bwd.hpp | 24 + .../gpu/groupnorm_bwd_data.hpp | 64 ++ .../gpu/layernorm_bwd_data.hpp | 84 +++ .../gpu/normalization_fwd.hpp | 8 +- .../gpu/normalization_fwd_swish.hpp | 4 +- .../gpu/normalization_bwd_data/CMakeLists.txt | 8 + ...device_groupnorm_bwd_data_f32_instance.cpp | 22 + ...vice_layernorm2d_bwd_data_f16_instance.cpp | 23 + ...vice_layernorm2d_bwd_data_f32_instance.cpp | 23 + ...normalization_bwd_data_instance_common.hpp | 73 +++ .../CMakeLists.txt | 8 + ..._groupnorm_bwd_gamma_beta_f32_instance.cpp | 23 + ...ayernorm2d_bwd_gamma_beta_f16_instance.cpp | 24 + ...ayernorm2d_bwd_gamma_beta_f32_instance.cpp | 24 + ...ization_bwd_gamma_beta_instance_common.hpp | 73 +++ .../device_groupnorm_fwd_f16_instance.cpp | 2 +- ...evice_groupnorm_fwd_swish_f16_instance.cpp | 2 +- .../device_layernorm2d_fwd_f16_instance.cpp | 2 +- .../device_layernorm4d_fwd_f16_instance.cpp | 2 +- .../normalization_fwd_instance_common.hpp | 74 +-- .../profile_groupnorm_bwd_data_impl.hpp | 250 ++++++++ .../profile_layernorm_bwd_data_impl.hpp | 255 ++++++++ profiler/src/CMakeLists.txt | 3 + profiler/src/profile_groupnorm_bwd_data.cpp | 104 ++++ profiler/src/profile_groupnorm_fwd.cpp | 2 +- profiler/src/profile_layernorm_bwd_data.cpp | 112 ++++ profiler/src/profile_layernorm_fwd.cpp | 4 +- test/CMakeLists.txt | 1 + test/normalization_bwd_data/CMakeLists.txt | 13 + .../test_groupnorm_bwd_data_fp32.cpp | 51 ++ .../test_layernorm2d_bwd_data_fp32.cpp | 48 ++ .../test_groupnorm_fwd_fp16.cpp | 4 +- .../test_groupnorm_fwd_fp32.cpp | 2 +- .../test_layernorm2d_fwd_fp16.cpp | 4 +- .../test_layernorm4d_fwd_fp16.cpp | 4 +- 65 files changed, 3050 insertions(+), 110 deletions(-) create mode 100644 client_example/05_layernorm/layernorm2d_bwd_data.cpp create mode 100644 client_example/18_groupnorm/groupnorm_bwd_data.cpp rename client_example/18_groupnorm/{groupnorm_swish.cpp => groupnorm_swish_fwd.cpp} (100%) create mode 100644 example/53_layernorm2d_bwd/CMakeLists.txt rename example/{53_layernorm_bwd/layernorm2d_bwd_fp16.cpp =>
53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp} (62%) delete mode 100644 example/53_layernorm_bwd/CMakeLists.txt rename example/54_groupnorm_bwd/{groupnorm_bwd_fp16.cpp => groupnorm_bwd_fp32.cpp} (62%) create mode 100644 include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp create mode 100644 profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp create mode 100644 profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp create mode 100644 profiler/src/profile_groupnorm_bwd_data.cpp create mode 100644 profiler/src/profile_layernorm_bwd_data.cpp create mode 100644 test/normalization_bwd_data/CMakeLists.txt create mode 100644 test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp create mode 100644 test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index c37f208db1..11f9222873 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -185,6 +185,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index 756889562e..e845c120d8 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -204,6 +204,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 8d2a8c234a..e77b67c905 
100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -197,6 +197,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index c02df018fd..7648da9cac 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -190,6 +190,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index 3d5fb60048..93f8847c62 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -200,6 +200,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/05_layernorm/CMakeLists.txt b/client_example/05_layernorm/CMakeLists.txt index 9cbfc2b763..246f877cde 100644 --- a/client_example/05_layernorm/CMakeLists.txt +++ b/client_example/05_layernorm/CMakeLists.txt @@ -1,3 +1,6 @@ +add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp) +target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations) + add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp) target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/05_layernorm/layernorm2d_bwd_data.cpp b/client_example/05_layernorm/layernorm2d_bwd_data.cpp new file mode 100644 index 0000000000..9f26cb6840 --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_data.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
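+
+// A summary of the math this example exercises (assuming, as below, a
+// row-major [M, N] problem reduced over N; it mirrors the b/c/dx comments
+// in gridwise_normalization_bwd_data.hpp):
+//   ds = sum(dy * gamma * x, axis=1)
+//   db = sum(dy * gamma,     axis=1)
+//   b  = (db * mean - ds) * inv_std^3 / N
+//   c  = -b * mean - db * inv_std / N
+//   dx = inv_std * dy * gamma + b * x + c
+// gamma is broadcast over M (gammaStrides {0, 1}); mean and inv_std are
+// broadcast over N (strides {1, 0}); reduceDims is {1}.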
+ +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DXDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N); + SimpleDeviceMem x_dev(sizeof(XDataType) * M * N); + SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * N); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem dx_dev(sizeof(DXDataType) * M * N); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N + + sizeof(GammaDataType) * N + sizeof(MeanInvStdDataType) * M * 2 + + sizeof(DXDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + 
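+        // best_op_id stays -1 and found stays false when no instance supports
+        // the problem, so this guard (the same check this patch adds to the
+        // other client examples) keeps op_ptrs[best_op_id] from being indexed
+        // out of range.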
auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/05_layernorm/layernorm2d_fwd.cpp b/client_example/05_layernorm/layernorm2d_fwd.cpp index 19ddd614de..420225b613 100644 --- a/client_example/05_layernorm/layernorm2d_fwd.cpp +++ b/client_example/05_layernorm/layernorm2d_fwd.cpp @@ -16,7 +16,7 @@ using XDataType = ck::half_t; using GammaDataType = ck::half_t; using BetaDataType = ck::half_t; using YDataType = ck::half_t; -using SaveMeanInvStdDataType = float; +using SaveMeanInvStdDataType = ck::half_t; using PassThrough = ck::tensor_operation::element_wise::PassThrough; #define SAVE_MEAN_INV_STD @@ -150,6 +150,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/05_layernorm/layernorm4d_fwd.cpp b/client_example/05_layernorm/layernorm4d_fwd.cpp index 9a7ecfd87e..fa408dc751 100644 --- a/client_example/05_layernorm/layernorm4d_fwd.cpp +++ b/client_example/05_layernorm/layernorm4d_fwd.cpp @@ -16,7 +16,7 @@ using XDataType = ck::half_t; using GammaDataType = ck::half_t; using BetaDataType = ck::half_t; using YDataType = ck::half_t; -using SaveMeanInvStdDataType = float; +using SaveMeanInvStdDataType = ck::half_t; using PassThrough = ck::tensor_operation::element_wise::PassThrough; #define SAVE_MEAN_INV_STD @@ -155,6 +155,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index 2ccad27a88..a62af76635 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -140,6 +140,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index bc4a6fe0bf..8326f0758c 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -142,6 +142,7 @@ int main() << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = 
op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index c74d7c6bd8..cde4713b23 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -204,6 +204,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt index dee85f9a60..deb50f6fce 100644 --- a/client_example/18_groupnorm/CMakeLists.txt +++ b/client_example/18_groupnorm/CMakeLists.txt @@ -1,2 +1,5 @@ -add_executable(client_groupnorm_swish groupnorm_swish.cpp) -target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_other_operations) +add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp) +target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations) + +add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp) +target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/18_groupnorm/groupnorm_bwd_data.cpp b/client_example/18_groupnorm/groupnorm_bwd_data.cpp new file mode 100644 index 0000000000..01ca21ba57 --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_bwd_data.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DXDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t length = N * H * W * G * C; + + std::vector strideDy = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector strideX = strideDy; + std::vector strideDx = strideDy; + + std::vector strideGamma = {0, 0, 0, C, 1}; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * length); + SimpleDeviceMem x_dev(sizeof(XDataType) * length); + SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * G * C); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem dx_dev(sizeof(DXDataType) * length); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + 
DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(DYDataType) * length + sizeof(XDataType) * length + + sizeof(GammaDataType) * G * C + + sizeof(MeanInvStdDataType) * N * G * 2 + + sizeof(DXDataType) * length; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp similarity index 100% rename from client_example/18_groupnorm/groupnorm_swish.cpp rename to client_example/18_groupnorm/groupnorm_swish_fwd.cpp diff --git a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp index 94a57cd029..a740c22f91 100644 --- a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp +++ b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp @@ -191,6 +191,7 @@ int main(int argc, 
char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp index fb63e20147..65ba46fcd2 100644 --- a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp +++ b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp @@ -117,6 +117,7 @@ int main() << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/example/53_layernorm2d_bwd/CMakeLists.txt b/example/53_layernorm2d_bwd/CMakeLists.txt new file mode 100644 index 0000000000..a58b1109f7 --- /dev/null +++ b/example/53_layernorm2d_bwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp) diff --git a/example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp similarity index 62% rename from example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp rename to example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp index f2e6bfb44d..0b0a3e72ad 100644 --- a/example/53_layernorm_bwd/layernorm2d_bwd_fp16.cpp +++ b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp @@ -15,16 +15,17 @@ #include "ck/library/utility/literals.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" -using DYDataType = ck::half_t; -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; using MeanInvStdDataType = float; -using DGammaDataType = ck::half_t; -using DBetaDataType = ck::half_t; -using DXDataType = ck::half_t; +using DGammaDataType = float; +using DBetaDataType = float; +using DXDataType = float; using ComputeDataType = float; constexpr int Rank = 2; @@ -39,6 +40,7 @@ constexpr int NumReduceDim = 1; // inv_std: [M, 1] // Output shape +// dx: [M, N] // dgamma: [1, N] // dbeta: [1, N] @@ -46,8 +48,34 @@ constexpr int NumReduceDim = 1; // dbeta = reduce_sum(dy, axis=0) // [CAUSION] -// In DeviceNormalizationBwdGammaBetaImpl, M is invarient dimension, K is reduced dimension -// Hence, M in this example and DeviceNormalizationBwdGammaBetaImpl is different +// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant +// dimension, K is reduced dimension Hence, M in this example and +// DeviceNormalizationBwdGammaBetaImpl is different +using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< + DYDataType, + XDataType, + GammaDataType, + MeanInvStdDataType, + ComputeDataType, + DXDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 1, // MThreadSliceSize + 4, // KThreadSliceSize + true, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + true, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsGammaFastestDimReduced + 4, // GammaSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + true, // IsDXFastestDimReduced + 4>; // DXDstVectorSize + using 
GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< DYDataType, XDataType, @@ -58,18 +86,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio Rank, NumReduceDim, 256, // BlockSize - 8, // ClusterInvarient - 32, // ClusterReduce - 8, // SliceInvarient - 1, // SliceReduce + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 4, // MThreadSliceSize + 1, // KThreadSliceSize false, // IsDYFastestDimReduced - 8, // DYSrcVectorSize + 4, // DYSrcVectorSize false, // IsXFastestDimReduced - 8, // XSrcVectorSize + 4, // XSrcVectorSize true, // IsMeanInvStdFastestDimReduced 1, // MeanInvStdSrcVectorSize - 1, // DGammaDstVectorSize - 1>; // DBetaDstVectorSize + 4, // DGammaDstVectorSize + 4>; // DBetaDstVectorSize int main() { @@ -96,16 +124,48 @@ int main() DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); dy_dev.ToDevice(dy.mData.data()); x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); mean_dev.ToDevice(mean.mData.data()); inv_std_dev.ToDevice(inv_std.mData.data()); + // backward x + auto x_device_instance = XDeviceInstance{}; + + auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto x_invoker_ptr = x_device_instance.MakeInvokerPointer(); + x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // backward gamma & beta auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; auto gamma_beta_argument_ptr = gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths @@ -126,7 +186,8 @@ int main() if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) { - std::cout << "The runtime parameters are not supported" << std::endl; + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; return 1; }; @@ -156,9 +217,11 @@ int main() dgamma_dev.FromDevice(dgamma.mData.data()); dbeta_dev.FromDevice(dbeta.mData.data()); + dx_dev.FromDevice(dx.mData.data()); pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3); } return (pass ? 
0 : 1); diff --git a/example/53_layernorm_bwd/CMakeLists.txt b/example/53_layernorm_bwd/CMakeLists.txt deleted file mode 100644 index 24db221523..0000000000 --- a/example/53_layernorm_bwd/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_layernorm2d_bwd_fp16 layernorm2d_bwd_fp16.cpp) diff --git a/example/54_groupnorm_bwd/CMakeLists.txt b/example/54_groupnorm_bwd/CMakeLists.txt index ac548cbc79..2cb103499c 100644 --- a/example/54_groupnorm_bwd/CMakeLists.txt +++ b/example/54_groupnorm_bwd/CMakeLists.txt @@ -1 +1 @@ -add_example_executable(example_groupnorm_bwd_fp16 groupnorm_bwd_fp16.cpp) +add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp) diff --git a/example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp similarity index 62% rename from example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp rename to example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp index 1537a014d4..6cf1b2ff91 100644 --- a/example/54_groupnorm_bwd/groupnorm_bwd_fp16.cpp +++ b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp @@ -15,23 +15,58 @@ #include "ck/library/utility/literals.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" -using DYDataType = ck::half_t; -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; using MeanInvStdDataType = float; -using DGammaDataType = ck::half_t; -using DBetaDataType = ck::half_t; -using DXDataType = ck::half_t; +using DGammaDataType = float; +using DBetaDataType = float; +using DXDataType = float; using ComputeDataType = float; constexpr int Rank = 5; constexpr int NumReduceDim = 3; // Grouprnorm -// kernel: M , K +// kernel 1: M , K +// dy: N, H, W, G, C -> N * G, H * W * C +// x: N, H, W, G, C -> N * G, H * W * C +// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C +// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 +// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 + +// dx: N, H, W, G, C -> N * G, H * W * C + +using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< + DYDataType, + XDataType, + GammaDataType, + MeanInvStdDataType, + ComputeDataType, + DXDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 1, // MThreadSliceSize + 4, // KThreadSliceSize + true, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + true, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsGammaFastestDimReduced + 4, // GammaSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + true, // IsDXFastestDimReduced + 4>; // DXDstVectorSize + +// kernel 2: M , K // dy: N, H, W, G, C -> G * C, N * H * W // x: N, H, W, G, C -> G * C, N * H * W // mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 @@ -52,18 +87,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio Rank, NumReduceDim, 256, // BlockSize - 8, // ClusterInvarient + 8, // ClusterInvariant 32, // ClusterReduce - 8, // SliceInvarient + 4, // SliceInvariant 1, // SliceReduce false, // IsDYFastestDimReduced - 8, // DYSrcVectorSize + 4, // DYSrcVectorSize false, // IsXFastestDimReduced - 8, // XSrcVectorSize + 4, // XSrcVectorSize false, // IsMeanInvStdFastestDimReduced 1, // 
MeanInvStdSrcVectorSize - 1, // DGammaDstVectorSize - 1>; // DBetaDstVectorSize + 4, // DGammaDstVectorSize + 4>; // DBetaDstVectorSize int main() { @@ -93,20 +128,55 @@ int main() DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); dy_dev.ToDevice(dy.mData.data()); x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); mean_dev.ToDevice(mean.mData.data()); inv_std_dev.ToDevice(inv_std.mData.data()); std::vector dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; std::vector xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; + std::vector gammaStrides = {0, 0, 0, C, 1}; std::vector meanStrides = {G, 0, 0, 1, 0}; std::vector invStdStrides = {G, 0, 0, 1, 0}; + std::vector dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()}; + + // backward x + auto x_device_instance = XDeviceInstance{}; + + auto x_argument_ptr = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths + dyStrides, // dyStrides + xStrides, // xStrides + gammaStrides, // gammaStrides + meanStrides, // meanStrides + invStdStrides, // invStdStrides + dxStrides, // dxStrides + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto x_invoker_ptr = x_device_instance.MakeInvokerPointer(); + x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // backward gamma & beta auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; auto gamma_beta_argument_ptr = @@ -128,7 +198,8 @@ int main() if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) { - std::cout << "The runtime parameters are not supported" << std::endl; + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; return 1; }; @@ -158,9 +229,11 @@ int main() dgamma_dev.FromDevice(dgamma.mData.data()); dbeta_dev.FromDevice(dbeta.mData.data()); + dx_dev.FromDevice(dx.mData.data()); pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3); } return (pass ? 0 : 1); diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp new file mode 100644 index 0000000000..327acfeb53 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
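+
+// A minimal CPU sketch of what implementations of this interface compute,
+// kept as a comment so the header stays declaration-only. It assumes a
+// row-major [M, N] problem reduced over N (the interface itself takes
+// arbitrary lengths/strides/reduceDims); the helper name is illustrative:
+//
+//   void normalization_bwd_data_ref(const float* dy, const float* x,
+//                                   const float* gamma, // [N], broadcast over M
+//                                   const float* mean,  // [M]
+//                                   const float* rstd,  // [M], 1 / sqrt(var + eps)
+//                                   float* dx, int M, int N)
+//   {
+//       for(int m = 0; m < M; ++m)
+//       {
+//           // ds = sum(dy * gamma * x), db = sum(dy * gamma) over the reduced dim
+//           float ds = 0.f, db = 0.f;
+//           for(int n = 0; n < N; ++n)
+//           {
+//               ds += dy[m * N + n] * gamma[n] * x[m * N + n];
+//               db += dy[m * N + n] * gamma[n];
+//           }
+//           const float b = (db * mean[m] - ds) * rstd[m] * rstd[m] * rstd[m] / N;
+//           const float c = -b * mean[m] - db * rstd[m] / N;
+//           for(int n = 0; n < N; ++n)
+//               dx[m * N + n] =
+//                   rstd[m] * dy[m * N + n] * gamma[n] + b * x[m * N + n] + c;
+//       }
+//   }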
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdDataPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp new file mode 100644 index 0000000000..86689af0b7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp @@ -0,0 +1,465 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { +template +__global__ void +kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M_K dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) +{ + GridwiseNormalizationBwd::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dx_grid_desc_m_k, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_gamma_global, + p_mean_global, + p_inv_std_global, + p_dx_global); +}; + +template +struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; + static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 
1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) || + (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto Make2dDescriptor(const std::vector& lengths, + const std::vector& strides, + int numBlockTileIteration) + { + const auto tupleLengths = make_tuple_from_array(lengths, Number{}); + const auto tupleStrides = make_tuple_from_array(strides, Number{}); + + const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(lengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + + return transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, pad_M), + make_right_pad_transform(reduceLength, pad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k_padded; + } + + using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1)); + + using GridwiseNormalizationBwdDataGeneric = + GridwiseNormalizationBwdData_mk_to_mk; + + using GridwiseNormalizationBwdDataSweepOnce = + GridwiseNormalizationBwdData_mk_to_mk; + + struct 
Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const GammaDataType* p_gamma, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DXDataType* p_dx) + : p_dy_(p_dy), + p_x_(p_x), + p_gamma_(p_gamma), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dx_(p_dx) + { + lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + dxStrides_ = shuffle_tensor_dimensions(dxStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(lengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_); + dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_); + + isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize; + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DXDataType* p_dx_; + + std::vector lengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector dxStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // tensor descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + GridDesc_M_K dx_grid_desc_m_k_; + + bool isSweeponce_; + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + auto KernelSelector(bool isSweepOnce) + { + return isSweepOnce + ? 
kernel_normalization_bwd_data + : kernel_normalization_bwd_data; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = KernelSelector(arg.isSweeponce_); + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dx_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_gamma_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dx_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dyStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->xStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->gammaStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->meanStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->invStdStrides_); + + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dxStrides_); + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) override + { + if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + gammaStrides.size() != Rank || meanStrides.size() != Rank || + invStdStrides.size() != Rank || dxStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + dyStrides, + xStrides, + gammaStrides, + meanStrides, + invStdStrides, + dxStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dx)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize; + str 
<< ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp index 43d2db4624..c35652bfe8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp @@ -14,7 +14,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" -// M is invarient dimension, K is reduced dimension +// M is Invariant dimension, K is reduced dimension namespace ck { namespace tensor_operation { namespace device { @@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl Rank, NumReduceDim> { - static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; @@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); - static_assert( - ((MThreadSliceSize % DGammaDstVectorSize == 0) || - (MThreadSliceSize % DBetaDstVectorSize == 0)), - "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " - "check!"); - static_assert( (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " "check!"); + static_assert( + ((MThreadSliceSize % DGammaDstVectorSize == 0) || + (MThreadSliceSize % DBetaDstVectorSize == 0)), + "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " + "check!"); + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; @@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl GridDesc_M dgamma_grid_desc_m_; GridDesc_M dbeta_grid_desc_m_; - index_t MRaw_; // invarient length + index_t MRaw_; // Invariant length index_t KRaw_; // reduce length }; @@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl { return std::make_unique(); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ; + str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">"; + // clang-format on + + return str.str(); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp index 254d60ea38..fa7b9cbda0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp @@ -19,7 +19,7 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) -// M: Invarient length +// M: Invariant length // K: Reduce length (Calculate mean and variance along K dimension) // eg. Length = [N, C, H, W], reduce dim = [C, H, W] // Then, M = N, K = C * H * W @@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwdinvariant_lowest_length_); - if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp index 6a117920f4..7fe6502bab 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp @@ -108,7 +108,7 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) -// M: Invarient length +// M: Invariant length // K: Reduce length (Calculate mean and variance along K dimension) // eg. Length = [N, C, H, W], reduce dim = [C, H, W] // Then, M = N, K = C * H * W @@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd +struct GridwiseNormalizationBwdData_mk_to_mk +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) || + (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) || + (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using GammaThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using DXThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + 
make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M_K& dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_val_buf = make_dynamic_buffer( + p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto gamma_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dx_thread_buf = StaticBuffer{}; + + auto ds_thread_buf = + StaticBuffer{}; + + auto db_thread_buf = + StaticBuffer{}; + + // thread id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + 
mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + ComputeDataType reduce_size = type_convert( + dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + ds_thread_buf(I) = type_convert(0.0f); + db_thread_buf(I) = type_convert(0.0f); + }); + + // Separate sweep once and sweep twice pipeline + // Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K + // we don't need to use loop to read x, dy, gamma twice + if constexpr(SweepOnce) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + } // end of sweep once + else // Sweep Twice 
pipeline + { + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + } // end of first sweep + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + // reverse read for using dy, gamma and x in the cache + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + // move to tail + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); + + // move from start to tail + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * 
mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp index e80c360bb2..21248e3a0a 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp @@ -35,7 +35,7 @@ template struct GridwiseNormalizationBwdGammaBeta_mk_to_k { - // if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); @@ -44,6 +44,15 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize, + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + using ThreadClusterLengths_M_K = Sequence; using DYThreadBufferDimAccessOrder = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp index 37f875d07a..2cc9a50ba0 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp @@ -16,6 +16,31 @@ namespace ck { namespace tensor_operation { namespace host { +// def normalization_backward_x(dy, x, gamma, x_mean, rstd, reduce_axis, reduce_size): +// ds = np.sum(dy * gamma 
* x, axis=reduce_axis, keepdims=True) +// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True) +// b = (db * x_mean - ds) * rstd ** (3) / reduce_size +// c = -b * x_mean - db * rstd / reduce_size +// dx = rstd * dy * gamma + b * x + c +// return dx + +// def normalization_backward_gamma_beta(dy, x, x_mean, rstd, reduce_axis): +// # Assume shape of gamma and beta are the same +// dgamma = np.sum(dy * (x - x_mean) * rstd, axis=reduce_axis, keepdims=True) +// dbeta = np.sum(dy, axis=reduce_axis, keepdims=True) +// return dgamma, dbeta + +// def groupnorm_backward(dy, x, gamma, x_mean, rstd): +// # dy, x = [N, H, W, G, C], gamma = [1, 1, 1, G, C], x_mean, rstd = [N, 1, 1, G, 1] +// N, H, W, G, C = x.shape +// dx = normalization_backward_x( +// dy, x, gamma, x_mean, rstd, (1, 2, 4), H * W * C) +// dgamma, dbeta = normalization_backward_gamma_beta( +// dy, x, x_mean, rstd, (0, 1, 2)) +// return dx, dgamma, dbeta + +// Reference (Layernorm and groupnorm): +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cpu/group_norm_kernel.cpp#L655 template +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP32 +// FP32 +void add_device_groupnorm_bwd_data_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalizationBwdData> +{ + using DeviceOp = DeviceNormalizationBwdData; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_groupnorm_bwd_data_f32_instances(op_ptrs); + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp new file mode 100644 index 0000000000..c46cdec336 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
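
The Python reference above maps one-to-one onto scalar code. Below is a minimal host-side sketch of the backward-data formula, assuming a single normalized row of length n in fp32; the function and variable names are illustrative, not part of the CK API.

#include <cstddef>
#include <vector>

// Scalar normalization_backward_x for one row: accumulate the two
// reductions ds = sum(dy * gamma * x) and db = sum(dy * gamma), then apply
// b = (db * mean - ds) * rstd^3 / n and c = -b * mean - db * rstd / n.
std::vector<float> normalization_bwd_x_row(const std::vector<float>& dy,
                                           const std::vector<float>& x,
                                           const std::vector<float>& gamma,
                                           float mean,
                                           float rstd)
{
    const std::size_t n = dy.size();
    float ds = 0.f;
    float db = 0.f;
    for(std::size_t k = 0; k < n; ++k)
    {
        ds += dy[k] * gamma[k] * x[k];
        db += dy[k] * gamma[k];
    }
    const float b = (db * mean - ds) * rstd * rstd * rstd / n;
    const float c = -b * mean - db * rstd / n;
    std::vector<float> dx(n);
    for(std::size_t k = 0; k < n; ++k)
        dx[k] = rstd * dy[k] * gamma[k] + b * x[k] + c;
    return dx;
}

This is the same ds/db accumulation and b/c epilogue that the gridwise kernel earlier in this patch performs per thread slice; its sweep-twice pipeline simply splits the ds/db reduction and the dx computation into two loops over the K tiles, walking the second loop backward so the last tiles read are still warm in cache.
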
+ +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +// FP16 +void add_device_layernorm2d_bwd_data_f16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_FP32 +// FP32 +void add_device_layernorm2d_bwd_data_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalizationBwdData> +{ + using DeviceOp = DeviceNormalizationBwdData; + + static auto GetInstances() + { + std::vector> op_ptrs; +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_data_f16_instances(op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_data_f32_instances(op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp index 29c9f8b2c0..d19e50297f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp @@ -20,15 +20,15 @@ namespace instance { // FP16 void add_device_normalization_fwd_rank_2_1_f16_instances( std::vector< - std::unique_ptr>>&); + std::unique_ptr>>&); void add_device_normalization_fwd_rank_4_3_f16_instances( std::vector< - std::unique_ptr>>&); + std::unique_ptr>>&); void add_device_normalization_fwd_rank_5_3_f16_instances( std::vector< - std::unique_ptr>>&); + std::unique_ptr>>&); #endif #ifdef CK_ENABLE_FP32 // FP32 @@ -76,7 +76,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(Rank == 2 && NumReduceDim == 1) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp index 563f164fd2..cc15e7bacc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp @@ -19,7 +19,7 @@ namespace instance { // FP16 void add_device_normalization_fwd_rank_5_3_swish_f16_instances( - std::vector>>&); + std::vector>>&); // FP32 void add_device_normalization_fwd_rank_5_3_swish_f32_instances( @@ -61,7 +61,7 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(Rank == 5 && NumReduceDim == 3) { diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000..9f3dd9d94c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/CMakeLists.txt @@ 
-0,0 +1,8 @@ +set(DEVICE_NORMALIZATION_BWD_DATA_INSTANCES) + +list(APPEND DEVICE_NORMALIZATION_BWD_DATA_INSTANCES + device_groupnorm_bwd_data_f32_instance.cpp + device_layernorm2d_bwd_data_f16_instance.cpp + device_layernorm2d_bwd_data_f32_instance.cpp) + +add_instance_library(device_normalization_bwd_data_instance ${DEVICE_NORMALIZATION_BWD_DATA_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp new file mode 100644 index 0000000000..7b4974d648 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_groupnorm_bwd_data_f32_instance.cpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_data_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_groupnorm_bwd_data_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_generic_instance{}); + add_device_operation_instances(instances, device_groupnorm_bwd_data_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp new file mode 100644 index 0000000000..097f9a3814 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f16_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_data_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_data_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_data_f16_generic_instance<2, 1>{}); + add_device_operation_instances(instances, device_layernorm_bwd_data_f16_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp new file mode 100644 index 0000000000..d885a77e5c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/device_layernorm2d_bwd_data_f32_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
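
For context on how these registrations are consumed: a client asks the factory for every instance matching its type/rank signature and probes each one. A hedged sketch follows; the names are illustrative, it assumes the interface parameters are ordered <DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim> as in the factory specialization above, and it assumes the groupnorm_bwd_data.hpp header from this patch is included.

#include <memory>
#include <vector>

bool run_first_supported(const std::vector<ck::index_t>& lengths,
                         const std::vector<ck::index_t>& stride_dy,
                         const std::vector<ck::index_t>& stride_x,
                         const std::vector<ck::index_t>& stride_gamma,
                         const std::vector<ck::index_t>& stride_mean,
                         const std::vector<ck::index_t>& stride_inv_std,
                         const std::vector<ck::index_t>& stride_dx,
                         const std::vector<ck::index_t>& reduce_dims,
                         const void* p_dy,
                         const void* p_x,
                         const void* p_gamma,
                         const void* p_mean,
                         const void* p_inv_std,
                         void* p_dx)
{
    // fp32 groupnorm backward (data): rank-5 input, 3 reduced dimensions.
    using DeviceOp = ck::tensor_operation::device::
        DeviceNormalizationBwdData<float, float, float, float, float, 5, 3>;

    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op_ptr : op_ptrs)
    {
        // Argument order mirrors the MakeArgumentPointer call used by the
        // profiler implementation later in this patch.
        auto arg = op_ptr->MakeArgumentPointer(lengths, stride_dy, stride_x,
                                               stride_gamma, stride_mean,
                                               stride_inv_std, stride_dx,
                                               reduce_dims, p_dy, p_x, p_gamma,
                                               p_mean, p_inv_std, p_dx);
        if(!op_ptr->IsSupportedArgument(arg.get()))
            continue; // tuned instance rejected this shape; try the next one
        op_ptr->MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
        return true;
    }
    return false; // only possible if even the generic instance is excluded
}
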
+ +#include "normalization_bwd_data_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_data_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_data_f32_generic_instance<2, 1>{}); + add_device_operation_instances(instances, device_layernorm_bwd_data_f32_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp new file mode 100644 index 0000000000..4f72a8782b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_data/normalization_bwd_data_instance_common.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using device_layernorm_bwd_data_f16_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize> + DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +template +using device_layernorm_bwd_data_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +template +using device_layernorm_bwd_data_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize> + DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +template +using device_layernorm_bwd_data_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +using device_groupnorm_bwd_data_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsGammaFastestDimReduced, GammaSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, IsDXFastestDimReduced, DXDstVectorSize> + 
DeviceNormalizationBwdDataImpl, + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +using device_groupnorm_bwd_data_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdDataImpl + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt new file mode 100644 index 0000000000..686fb5e665 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/CMakeLists.txt @@ -0,0 +1,8 @@ +set(DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES) + +list(APPEND DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES + device_groupnorm_bwd_gamma_beta_f32_instance.cpp + device_layernorm2d_bwd_gamma_beta_f16_instance.cpp + device_layernorm2d_bwd_gamma_beta_f32_instance.cpp) + +add_instance_library(device_normalization_bwd_gamma_beta_instance ${DEVICE_NORMALIZATION_BWD_GAMMA_BETA_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp new file mode 100644 index 0000000000..8eaf8a5684 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_groupnorm_bwd_gamma_beta_f32_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_gamma_beta_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_groupnorm_bwd_gamma_beta_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_groupnorm_bwd_gamma_beta_f32_instances{}); + add_device_operation_instances(instances, + device_groupnorm_bwd_gamma_beta_f32_generic_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp new file mode 100644 index 0000000000..aa399f56ec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
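
The gamma/beta backward instances registered here implement the second reference function quoted earlier in this patch: dgamma = sum(dy * (x - x_mean) * rstd) and dbeta = sum(dy), reduced over the batch-like axes. A minimal fp32 sketch for the 2D [M, K] layernorm case (illustrative names, not the CK API):

#include <cstddef>
#include <vector>

// dy and x are M x K row-major; mean and rstd carry one value per row;
// dgamma and dbeta get one value per column, reduced over the M rows.
void layernorm_bwd_gamma_beta_ref(const std::vector<float>& dy,
                                  const std::vector<float>& x,
                                  const std::vector<float>& mean,
                                  const std::vector<float>& rstd,
                                  std::size_t M,
                                  std::size_t K,
                                  std::vector<float>& dgamma,
                                  std::vector<float>& dbeta)
{
    dgamma.assign(K, 0.f);
    dbeta.assign(K, 0.f);
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t k = 0; k < K; ++k)
        {
            const float g = dy[m * K + k];
            dgamma[k] += g * (x[m * K + k] - mean[m]) * rstd[m];
            dbeta[k] += g;
        }
    }
}
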
+ +#include "normalization_bwd_gamma_beta_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f16_generic_instance<2, 1>{}); + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f16_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp new file mode 100644 index 0000000000..ba2966ba37 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "normalization_bwd_gamma_beta_instance_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f32_generic_instance<2, 1>{}); + add_device_operation_instances(instances, + device_layernorm_bwd_gamma_beta_f32_instances<2, 1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp new file mode 100644 index 0000000000..3f239527a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/normalization_bwd_gamma_beta_instance_common.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
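
Throughout these registration files the instance lists are std::tuples of fully specialized impl types, and add_device_operation_instances materializes one object per tuple element into the polymorphic instance vector. A hedged sketch of that mechanic (assumed behavior, not the actual CK helper):

#include <memory>
#include <tuple>
#include <vector>

// Push one default-constructed instance of every type in the tuple into
// the op list; BaseOp is the common device-op interface each Impl derives
// from, which is what makes the unique_ptr upcast legal.
template <typename BaseOp, typename... Impls>
void add_instances(std::vector<std::unique_ptr<BaseOp>>& op_ptrs,
                   std::tuple<Impls...>)
{
    (op_ptrs.push_back(std::make_unique<Impls>()), ...);
}

Keeping a separate *_generic_instance tuple with scalar vector sizes guarantees at least one entry survives IsSupportedArgument for odd extents and strides; the tuned tuples only win when vectorized access lines up.
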
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using device_layernorm_bwd_gamma_beta_f16_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize> + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +template +using device_layernorm_bwd_gamma_beta_f16_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +template +using device_layernorm_bwd_gamma_beta_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize> + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +template +using device_layernorm_bwd_gamma_beta_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +using device_groupnorm_bwd_gamma_beta_f32_instances = + // clang-format off + std::tuple < + // DYDataType, XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, IsDYFastestDimReduced, DYSrcVectorSize, IsXFastestDimReduced, XSrcVectorSize, IsMeanInvStdFastestDimReduced, MeanInvStdSrcVectorSize, DGammaDstVectorSize, DBetaDstVectorSize> + DeviceNormalizationBwdGammaBetaImpl, + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +using device_groupnorm_bwd_gamma_beta_f32_generic_instance = std::tuple< + // clang-format off + DeviceNormalizationBwdGammaBetaImpl + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp index 0f8bab973e..9b58e12e1d 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_fwd_rank_5_3_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git 
a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp index 9fbbab64e7..fe42f3dfec 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_groupnorm_fwd_swish_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Swish = ck::tensor_operation::element_wise::Swish; void add_device_normalization_fwd_rank_5_3_swish_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp index bfc2e465df..8fcab734d3 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm2d_fwd_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_fwd_rank_2_1_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp index 690489bbad..132eef0146 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/device_layernorm4d_fwd_f16_instance.cpp @@ -11,7 +11,7 @@ namespace instance { using Pass = ck::tensor_operation::element_wise::PassThrough; void add_device_normalization_fwd_rank_4_3_f16_instances( - std::vector>>& + std::vector>>& instances) { add_device_operation_instances(instances, diff --git a/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp b/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp index 60a55dd6e1..9486a033de 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp +++ b/library/src/tensor_operation_instance/gpu/normalization_fwd/normalization_fwd_instance_common.hpp @@ -23,24 +23,24 @@ using device_normalization_f16_instances = // clang-format off std::tuple < // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, // irregular size - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - 
DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl, - DeviceNormalizationFwdImpl + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, // irregular size + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl, + DeviceNormalizationFwdImpl // clang-format on >; @@ -49,31 +49,31 @@ using device_normalization_splitk_f16_instances = // clang-format off std::tuple < // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize, SaveMeanInvStdScalarPerVector> - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, // irregular size - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl, - DeviceNormalizationFwdSplitKImpl + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, // irregular size + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl, + DeviceNormalizationFwdSplitKImpl // clang-format on >; template using device_normalization_f16_generic_instance = std::tuple< // clang-format off - DeviceNormalizationFwdImpl + DeviceNormalizationFwdImpl // clang-format on >; diff --git a/profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp b/profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp new file mode 100644 index 0000000000..55ea08e0db --- /dev/null +++ b/profiler/include/profiler/profile_groupnorm_bwd_data_impl.hpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
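
The forward-instance churn above (an extra F32 parameter threaded through every DeviceNormalizationFwdImpl) is the new SaveMeanInvStdDataType: the forward pass now persists per-row statistics so the backward kernels in this patch can consume them instead of re-reducing x. A sketch of the saved quantities, assuming fp32 accumulation (names illustrative):

#include <cmath>

// Computed once per normalized row in the forward pass and stored next to
// y; the bwd-data and bwd-gamma/beta kernels read them back as
// MeanInvStdDataType rather than recomputing the reduction.
struct SavedStats
{
    float mean;    // E[x]
    float inv_std; // 1 / sqrt(Var[x] + epsilon)
};

inline SavedStats make_saved_stats(float sum_x, float sum_xx, float n, float eps)
{
    const float mean = sum_x / n;
    const float var  = sum_xx / n - mean * mean;
    return {mean, 1.0f / std::sqrt(var + eps)};
}
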
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_groupnorm_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need DGamma and DBeta here, just for reference class + using DGammaDataType = DXDataType; + using DBetaDataType = DXDataType; + + if(length.size() != 5) + return false; + + index_t N = length[0]; + index_t G = length[3]; + index_t C = length[4]; + + std::vector reduce_dim = {1, 2, 4}; + std::vector gammaLength = {G, C}; + + Tensor dy(length); + Tensor x(length); + Tensor gamma({G, C}); + Tensor mean({N, G}); + Tensor inv_std({N, G}); + Tensor dx(length); + + Tensor host_dx(length); + Tensor host_dgamma({G, C}); + Tensor host_dbeta({G, C}); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = strideDy; + std::vector strideDx = strideDy; + + std::vector strideGamma = {0, 0, 0, C, 1}; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + gamma.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dx.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dx.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + gamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + dx.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = 
std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + gamma.mDesc.GetElementSize() * sizeof(GammaDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dx.mDesc.GetElementSize() * sizeof(DXDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dx_dev.FromDevice(dx.mData.data()); + bool pass = ck::utils::check_err( + dx.mData, host_dx.mData, "Error: Incorrect results", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dx : ", host_dx.mData, ",") << std::endl; + LogRangeAsType(std::cout << "dx : ", dx.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." 
<< std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp b/profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp new file mode 100644 index 0000000000..e88a06122d --- /dev/null +++ b/profiler/include/profiler/profile_layernorm_bwd_data_impl.hpp @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_layernorm_bwd_data_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need DGamma and DBeta here, just for reference class + using DGammaDataType = DXDataType; + using DBetaDataType = DXDataType; + + if(length.size() != Rank || Rank < 2) + return false; + + // Assume normalize dimension except for batch (first) dimension + std::vector reduce_length{length.begin() + 1, length.end()}; + std::vector reduce_dim; + for(int i = 1; i < Rank; ++i) + reduce_dim.push_back(i); + + Tensor dy(length); + Tensor x(length); + Tensor gamma(reduce_length); + Tensor mean({length[0]}); + Tensor inv_std({length[0]}); + Tensor dx(length); + + Tensor host_dx(length); + Tensor host_dgamma(reduce_length); + Tensor host_dbeta(reduce_length); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = strideDy; + std::vector strideDx = strideDy; + + std::vector strideGamma = strideDy; + strideGamma[0] = 0; + + std::vector strideMeanInvStd{Rank, 0}; + strideMeanInvStd[0] = 1; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + gamma.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dx.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + gamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dx.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + gamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + 
dx.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + constexpr int NumReduceDim = Rank - 1; + + // add device normalization instances + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + gamma.mDesc.GetElementSize() * sizeof(GammaDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dx.mDesc.GetElementSize() * sizeof(DXDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dx_dev.FromDevice(dx.mData.data()); + bool pass = ck::utils::check_err( + dx.mData, host_dx.mData, "Error: Incorrect results", 1e-3, 1e-3); + + if(do_log) + { + 
LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dx : ", host_dx.mData, ",") << std::endl; + LogRangeAsType(std::cout << "dx : ", dx.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 0af3107157..7674b3b4f0 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -16,7 +16,9 @@ set(PROFILER_SOURCES profile_grouped_conv_fwd.cpp profile_grouped_conv_bwd_weight.cpp profile_reduce.cpp + profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp + profile_layernorm_bwd_data.cpp profile_layernorm_fwd.cpp profile_max_pool3d_fwd.cpp profile_avg_pool3d_bwd.cpp @@ -78,6 +80,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_w target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) diff --git a/profiler/src/profile_groupnorm_bwd_data.cpp b/profiler/src/profile_groupnorm_bwd_data.cpp new file mode 100644 index 0000000000..f9fea1db55 --- /dev/null +++ b/profiler/src/profile_groupnorm_bwd_data.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
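
Both profiler implementations above rate kernels by effective bandwidth: num_bytes counts each tensor exactly once (dy, x, gamma, mean and inv_std read, dx written), and gb_per_sec = num_bytes / 1e6 / avg_time converts correctly because avg_time is in milliseconds. A worked example for the groupnorm case with --length 1 16 16 32 40 in fp32 (the timing value is made up for illustration):

#include <cstddef>

// dy, x, dx  : 1 * 16 * 16 * 32 * 40 = 327,680 elements each
// gamma      : 32 * 40               =   1,280 elements
// mean, rstd : 1 * 32 each           =      64 elements combined
// num_bytes  = (3 * 327,680 + 1,280 + 64) * 4 = 3,937,536 bytes
constexpr double gb_per_sec(std::size_t num_bytes, double avg_time_ms)
{
    return num_bytes / 1.e6 / avg_time_ms; // bytes per ms / 1e6 == GB/s
}
// gb_per_sec(3'937'536, 0.05) is roughly 78.8 GB/s
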
+ +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_groupnorm_bwd_data_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct groupnormBwdDataArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_groupnorm_bwd_data() +{ + // eg: ckProfiler groupnorm_bwd_data 1 0 2 0 1 --length 1 16 16 32 40 + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1 16 16 32 40) \n" + << std::endl; +} + +int profile_groupnorm_bwd_data(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help_groupnorm_bwd_data(); + return 0; + } + + groupnormBwdDataArgParser arg_parser; + + // short unnamed options + const ck::DataTypeEnum data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + + using F32 = float; + + if(length.size() == 5) + { + if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_groupnorm_bwd_data_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + else + { + throw std::runtime_error("length should be 5"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("groupnorm_bwd_data", + "Group Normalization", + profile_groupnorm_bwd_data); diff --git a/profiler/src/profile_groupnorm_fwd.cpp b/profiler/src/profile_groupnorm_fwd.cpp index 3ba2f751cc..9a595bf7a7 100644 --- a/profiler/src/profile_groupnorm_fwd.cpp +++ b/profiler/src/profile_groupnorm_fwd.cpp @@ -98,7 +98,7 @@ int profile_groupnorm(int argc, char* argv[]) } else if(data_type == ck::DataTypeEnum::Half) { - ck::profiler::profile_groupnorm_impl( + ck::profiler::profile_groupnorm_impl( do_verification, init_method, do_log, time_kernel, length); } else diff --git a/profiler/src/profile_layernorm_bwd_data.cpp b/profiler/src/profile_layernorm_bwd_data.cpp new file mode 100644 index 0000000000..1f364d79ba --- /dev/null +++ b/profiler/src/profile_layernorm_bwd_data.cpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
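
The groupnormBwdDataArgParser above scans forward from each "--length" token and std::stoi-converts every following argv entry until one begins with '-'. An illustration of the resulting state for the invocation given in the help text, plus one caveat that follows directly from the loop condition:

// argv: ckProfiler groupnorm_bwd_data 1 0 2 0 1 --length 1 16 16 32 40
//   long_opts["length"] == {1, 16, 16, 32, 40}
//
// Caveat: because the scan stops at the first token starting with '-',
// a negative value would be treated as the next option and silently end
// the list, so extents must be positive integers; a non-numeric token in
// the run would make std::stoi throw.
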
+ +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_layernorm_bwd_data_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct layernormBwdDataArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_layernorm_bwd_data() +{ + // eg: ckProfiler layernorm_bwd_data 0 0 2 0 1 --length 1502 4096 + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1024 1024) \n" + << std::endl; +} + +int profile_layernorm_bwd_data(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help_layernorm_bwd_data(); + return 0; + } + + layernormBwdDataArgParser arg_parser; + + // short unnamed options + const ck::DataTypeEnum data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + + using F16 = ck::half_t; + using F32 = float; + + if(length.size() == 2) + { + constexpr int rank = 2; + + if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_layernorm_bwd_data_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_layernorm_bwd_data_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("layernorm_bwd_data", + "Layer Normalization", + profile_layernorm_bwd_data); diff --git a/profiler/src/profile_layernorm_fwd.cpp b/profiler/src/profile_layernorm_fwd.cpp index 9bd66e0cb8..a261bd7418 100644 --- a/profiler/src/profile_layernorm_fwd.cpp +++ b/profiler/src/profile_layernorm_fwd.cpp @@ -104,7 +104,7 @@ int profile_layernorm(int argc, char* argv[]) if(data_type == ck::DataTypeEnum::Half) { - ck::profiler::profile_layernorm_impl( + ck::profiler::profile_layernorm_impl( do_verification, init_method, do_log, time_kernel, length); } else if(data_type == ck::DataTypeEnum::Float) @@ -125,4 +125,4 @@ int profile_layernorm(int argc, char* argv[]) return 0; } -REGISTER_PROFILER_OPERATION("layernorm", "Layer Normalization", profile_layernorm); +REGISTER_PROFILER_OPERATION("layernorm_fwd", "Layer Normalization", profile_layernorm); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b325a3a7f8..6f7e18b0e7 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -140,6 +140,7 @@ add_subdirectory(grouped_convnd_bwd_weight) add_subdirectory(block_to_ctile_map) 
add_subdirectory(softmax) add_subdirectory(normalization_fwd) +add_subdirectory(normalization_bwd_data) add_subdirectory(data_type) add_subdirectory(elementwise_normalization) add_subdirectory(batchnorm) diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000..1b6decfed7 --- /dev/null +++ b/test/normalization_bwd_data/CMakeLists.txt @@ -0,0 +1,13 @@ +add_custom_target(test_normalization_bwd_data) +add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) + add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) +endif() + +add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) + add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) +endif() + diff --git a/test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp b/test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp new file mode 100644 index 0000000000..a7860955cd --- /dev/null +++ b/test/normalization_bwd_data/test_groupnorm_bwd_data_fp32.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_groupnorm_bwd_data_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestgroupnormBwdData : public ::testing::Test +{ + protected: + using DYDataType = std::tuple_element_t<0, Tuple>; + using XDataType = std::tuple_element_t<1, Tuple>; + using GammaDataType = std::tuple_element_t<2, Tuple>; + using MeanInvStdDataType = std::tuple_element_t<3, Tuple>; + using ComputeDataType = std::tuple_element_t<4, Tuple>; + using DXDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // Bwd data: [N, H, W, G, C], reduce H, W, C + std::vector> lengths = {{1, 1, 1, 1, 1}, + {1, 2, 3, 4, 5}, + {256, 9, 9, 9, 9}, + {1, 64, 64, 32, 10}, + {1, 32, 32, 32, 20}, + {1, 16, 16, 32, 40}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_groupnorm_bwd_data_impl( + true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // DYDataType XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestgroupnormBwdData, KernelTypes); +TYPED_TEST(TestgroupnormBwdData, Test_FP32) { this->Run(); } diff --git a/test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp b/test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp new file mode 100644 index 0000000000..870f24d064 --- /dev/null +++ b/test/normalization_bwd_data/test_layernorm2d_bwd_data_fp32.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
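// ---------------------------------------------------------------------------
// [Editor's sketch -- not part of this patch] The groupnorm test above uses
// GoogleTest's typed-test machinery: each std::tuple listed in
// ::testing::Types is one datatype configuration, and TYPED_TEST instantiates
// Run() once per tuple. A stripped-down, self-contained rendition of the same
// pattern (class and test names here are illustrative assumptions):
#include <tuple>
#include "gtest/gtest.h"

template <typename Tuple>
class TypedTupleTest : public ::testing::Test
{
    protected:
    using InDataType  = std::tuple_element_t<0, Tuple>;
    using OutDataType = std::tuple_element_t<1, Tuple>;

    void Run()
    {
        // trivial stand-in for a profile_*_impl() call: small integer values
        // survive a round trip through the output type
        InDataType in = InDataType(2);
        EXPECT_EQ(static_cast<InDataType>(static_cast<OutDataType>(in)), in);
    }
};

using KernelTypes = ::testing::Types<std::tuple<float, float>, std::tuple<int, double>>;

TYPED_TEST_SUITE(TypedTupleTest, KernelTypes);
TYPED_TEST(TypedTupleTest, RunsOncePerTuple) { this->Run(); }
// ---------------------------------------------------------------------------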
+ +#include "gtest/gtest.h" +#include "profiler/profile_layernorm_bwd_data_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestLayernorm2dBwdData : public ::testing::Test +{ + protected: + using DYDataType = std::tuple_element_t<0, Tuple>; + using XDataType = std::tuple_element_t<1, Tuple>; + using GammaDataType = std::tuple_element_t<2, Tuple>; + using MeanInvStdDataType = std::tuple_element_t<3, Tuple>; + using ComputeDataType = std::tuple_element_t<4, Tuple>; + using DXDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // Bwd data: [N, D], reduce D + std::vector> lengths = { + {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}}; + + for(auto length : lengths) + { + bool success = + ck::profiler::profile_layernorm_bwd_data_impl(true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // DYDataType XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestLayernorm2dBwdData, KernelTypes); +TYPED_TEST(TestLayernorm2dBwdData, Test_FP32) { this->Run(); } diff --git a/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp b/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp index 143c725257..c31161fb33 100644 --- a/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp +++ b/test/normalization_fwd/test_groupnorm_fwd_fp16.cpp @@ -47,8 +47,8 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> - std::tuple>; + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> + std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); TYPED_TEST(TestGroupnorm, Test_FP16) { this->Run(); } diff --git a/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp b/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp index 84a833c793..08d835ed37 100644 --- a/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp +++ b/test/normalization_fwd/test_groupnorm_fwd_fp32.cpp @@ -45,7 +45,7 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); diff --git a/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp b/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp index cc49ebe0ae..3234b2e159 100644 --- a/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp +++ b/test/normalization_fwd/test_layernorm2d_fwd_fp16.cpp @@ -41,8 +41,8 @@ class TestLayernorm2d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> - std::tuple>; + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> + std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); TYPED_TEST(TestLayernorm2d, Test_FP16) { this->Run(); } diff --git a/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp b/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp index a3bd388f7f..d1a7b9e3df 100644 --- a/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp +++ b/test/normalization_fwd/test_layernorm4d_fwd_fp16.cpp @@ -41,8 +41,8 @@ class TestLayernorm4d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, ComputeDataType, 
YDataType> - std::tuple>; + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, SaveMeanInvStdDataType> + std::tuple>; TYPED_TEST_SUITE(TestLayernorm4d, KernelTypes); TYPED_TEST(TestLayernorm4d, Test_FP16) { this->Run(); } From 12a8883c48b8ac03595bb4c5eb79d2fa53ff4599 Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Mon, 18 Dec 2023 19:35:00 -0800 Subject: [PATCH 28/75] Hip tensor permute unit test (#1068) * adding files for F32 example * adding functioning implementation with scalar multiplication and unary operator support * added fp 16 type check in unary square * updating scalar multiplication as an operator * functioning version with scalar operator * changing strides for col major * updated column major implementation * working column major implementation * cleaned up comments, rearranged/renamed files * small edits to 3d transpose profiler * adding test/profiler/instance files for hipTensor permute unit test * added more test instances * cleaned up errors, randomized input tensor, added more instances * turned off time printouts * removed conflicting transpose profiler * rearranged some files --- .../elementwise_permute_4D_fp16_col.cpp | 11 +- .../elementwise_permute_4D_fp32_col.cpp | 4 +- .../gpu/permute_scale.hpp | 77 +++++++ .../gpu/permute_scale/CMakeLists.txt | 2 + .../device_permute_scale_instances.cpp | 56 +++++ .../device_transpose_instances_3d.cpp | 8 - test/CMakeLists.txt | 1 + test/permute_scale/CMakeLists.txt | 6 + test/permute_scale/test_permute_scale.cpp | 36 +++ .../permute_scale/test_permute_scale_impl.hpp | 212 ++++++++++++++++++ 10 files changed, 399 insertions(+), 14 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp create mode 100644 test/permute_scale/CMakeLists.txt create mode 100644 test/permute_scale/test_permute_scale.cpp create mode 100644 test/permute_scale/test_permute_scale_impl.hpp diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index 9ed078f77e..f496d26a8a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -1,5 +1,6 @@ #include #include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" @@ -48,10 +49,8 @@ void host_elementwise4D(HostTensorB& B_nhwc, for(std::size_t n = 0; n < N; ++n) { ADataType tmp_val; - // auto a_val = A_nchw(n, c, h, w); auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; functor_b(tmp_val, a_val); - // functor_a(B_nhwc(n, h, w, c), scale * tmp_val); functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], scale * tmp_val); } @@ -62,12 +61,14 @@ int main() bool do_verification = true; bool time_kernel = true; - std::vector nchw = {4, 2, 1, 8}; - std::vector nhwc = {4, 1, 8, 2}; + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; Tensor a(nchw); Tensor b(nhwc); float scale = 1.f; auto i = 0; + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 1); for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) for(std::size_t c = 0; c < 
a.mDesc.GetLengths()[1]; ++c) @@ -75,7 +76,7 @@ int main() { a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + (h * nchw[3]) + w] = i; - i++; + i = dis(gen); } DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index be8894f2b2..619f481357 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -67,6 +67,8 @@ int main() float scale = 1.f; auto i = 0; + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 1); for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) @@ -74,7 +76,7 @@ int main() { a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + (h * nchw[3]) + w] = i; - i++; + i = dis(gen); } DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp new file mode 100644 index 0000000000..6ea1244c57 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_permute_scale_f16_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 4>>>&); + +void add_device_permute_scale_f32_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 4>>>&); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceElementwise> +{ + using DeviceOp = DeviceElementwise; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_f32_instances(op_ptrs); + } + else if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_f16_instances(op_ptrs); + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt new file mode 100644 index 0000000000..8b45c1ab07 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt @@ -0,0 +1,2 @@ +add_instance_library(device_permute_scale_instance + device_permute_scale_instances.cpp) diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp new file mode 100644 index 0000000000..fbbedd52e8 --- /dev/null +++ 
b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Pass = ck::tensor_operation::element_wise::PassThrough; +using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; +using Scale = ck::tensor_operation::element_wise::Scale; + +// clang-format off +using device_permute_scale_f16_instances = + std::tuple < + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 1, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 8, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 4, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 2, ck::Sequence<1>, ck::Sequence<1>> + >; + +using device_permute_scale_f32_instances = std::tuple< + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 1, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 8, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 4, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 2, ck::Sequence<1>, ck::Sequence<1>> + >; +// clang-format on + +void add_device_permute_scale_f16_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f16_instances{}); +} + +void add_device_permute_scale_f32_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/transpose/device_transpose_instances_3d.cpp b/library/src/tensor_operation_instance/gpu/transpose/device_transpose_instances_3d.cpp index 4efeb81885..0357af149c 100644 --- a/library/src/tensor_operation_instance/gpu/transpose/device_transpose_instances_3d.cpp +++ b/library/src/tensor_operation_instance/gpu/transpose/device_transpose_instances_3d.cpp @@ -19,22 +19,14 @@ void add_device_transpose_f16_instances( std::vector, ck::Tuple, PassThrough, 5>>>& instances) { -#ifdef CK_ENABLE_FP16 add_device_operation_instances(instances, device_transpose_f16_instances{}); -#else - ignore = instances; -#endif } void add_device_transpose_f32_instances( std::vector, ck::Tuple, PassThrough, 5>>>& instances) { -#ifdef CK_ENABLE_FP32 add_device_operation_instances(instances, device_transpose_f32_instances{}); -#else - ignore = instances; -#endif } } // namespace instance diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 6f7e18b0e7..94c5f2750f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -150,6 +150,7 @@ add_subdirectory(batched_gemm_multi_d) add_subdirectory(grouped_convnd_bwd_data) add_subdirectory(conv_tensor_rearrange) add_subdirectory(transpose) +add_subdirectory(permute_scale) add_subdirectory(wrapper) if(GPU_TARGETS 
MATCHES "gfx11") add_subdirectory(wmma_op) diff --git a/test/permute_scale/CMakeLists.txt b/test/permute_scale/CMakeLists.txt new file mode 100644 index 0000000000..be6aaf94aa --- /dev/null +++ b/test/permute_scale/CMakeLists.txt @@ -0,0 +1,6 @@ +add_custom_target(test_permute) +add_gtest_executable(test_permute_scale test_permute_scale.cpp) +if(result EQUAL 0) + target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) + add_dependencies(test_permute test_permute_scale) +endif() diff --git a/test/permute_scale/test_permute_scale.cpp b/test/permute_scale/test_permute_scale.cpp new file mode 100644 index 0000000000..518d3fc87a --- /dev/null +++ b/test/permute_scale/test_permute_scale.cpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_permute_scale_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestPermute : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + + void Run() + { + std::vector> lengths = { + {4, 2, 1, 8}, {1, 1, 1, 1}, {16, 8, 32, 64}, {32, 64, 128, 128}}; + + for(auto length : lengths) + { + bool success = + ck::test_permute_scale_impl(true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types, std::tuple>; + +TYPED_TEST_SUITE(TestPermute, KernelTypes); +TYPED_TEST(TestPermute, Test_FP16) { this->Run(); } +TYPED_TEST(TestPermute, Test_FP32) { this->Run(); } diff --git a/test/permute_scale/test_permute_scale_impl.hpp b/test/permute_scale/test_permute_scale_impl.hpp new file mode 100644 index 0000000000..3837e7ef5a --- /dev/null +++ b/test/permute_scale/test_permute_scale_impl.hpp @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" + +#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +namespace ck { +template +void host_elementwise4D(HostTensorB& B_nhwc, + const HostTensorA& A_nchw, + FunctorA functor_a, + FunctorB functor_b, + float scale) +{ + std::size_t N = A_nchw.mDesc.GetLengths()[0]; + std::size_t C = A_nchw.mDesc.GetLengths()[1]; + std::size_t H = A_nchw.mDesc.GetLengths()[2]; + std::size_t W = A_nchw.mDesc.GetLengths()[3]; + for(std::size_t w = 0; w < W; ++w) + for(std::size_t h = 0; h < H; ++h) + for(std::size_t c = 0; c < C; ++c) + for(std::size_t n = 0; n < N; ++n) + { + using tmp_type = ck::remove_reference_t; + tmp_type tmp_val = 0; + auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; + functor_b(tmp_val, a_val); + functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], + scale * tmp_val); + } +} + +template +bool test_permute_scale_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector lengths) +{ + bool pass = true; + + using ElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; + using Scale = ck::tensor_operation::element_wise::Scale; + float scale = 2.f; + + index_t N = lengths[0]; + index_t C = lengths[1]; + index_t H = lengths[2]; + index_t W = lengths[3]; + + std::vector nchw = {N, C, H, W}; + std::vector nhwc = {N, H, W, C}; + Tensor a(nchw); + Tensor b(nhwc); + Tensor host_b(nhwc); + + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::cout << "A: " << a.mDesc << std::endl; + std::cout << "B: " << b.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: a.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; + default: // a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0} + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 1); + auto i = 0; + for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) + for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) + for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) + for(std::size_t n = 0; n < a.mDesc.GetLengths()[0]; ++n) + { + a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + + (h * nchw[3]) + w] = i; + i = dis(gen); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + using DeviceOp = ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + ElementOp, + UnaryOp, + Scale, + 
NumDim>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + if(do_verification) + { + host_elementwise4D(host_b, a, ElementOp{}, UnaryOp{}, scale); + } + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + ElementOp{}, + UnaryOp{}, + Scale{scale}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + b_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + + pass &= ck::utils::check_err( + b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_instance_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + if(time_kernel) + { + LogRange(std::cout << "length = ", lengths, ",") << ", "; + std::cout << "best perf = " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + return true; +} + +} // namespace ck From 3726a1730e60242a27770c2153aa76cf3da75fb2 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 19 Dec 2023 07:15:24 -0800 Subject: [PATCH 29/75] add -Wno-pass-failed compiler flag (#1105) --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e4b9d8d4b..6fc22b18af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,7 @@ endif() #for f8/bf8_t type add_compile_options(-Wno-bit-int-extension) +add_compile_options(-Wno-pass-failed) if(DL_KERNELS) add_definitions(-DDL_KERNELS) From 3ab1838fb0614a8bbd81e99cecae1db81c0e8679 Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Tue, 19 Dec 2023 07:16:49 -0800 Subject: [PATCH 30/75] ROCm 6.0 replaces all __HIP_PLATFORM_HCC__ with __HIP_PLATFORM_AMD__ (#1106) * ROCm 6.0 replaces all __HIP_PLATFORM_HCC__ with __HIP_PLATFORM_AMD__ * make it backward compatible * Update .clang-tidy * Update ClangTidy.cmake --- .clang-tidy | 2 +- CMakeLists.txt | 6 +++++- cmake/ClangTidy.cmake | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index 5c2b781687..3815c654fe 100644 --- a/.clang-tidy +++ 
b/.clang-tidy @@ -1,3 +1,3 @@ CheckOptions: - key: bugprone-reserved-identifier.AllowedIdentifiers - value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__' + value: '__HIP_PLATFORM_HCC__;__HIP_PLATFORM_AMD__;__HIP_ROCclr__' diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fc22b18af..d78e887efb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,7 +244,11 @@ if( DEFINED CK_OVERRIDE_HIP_VERSION_PATCH ) endif() message(STATUS "Build with HIP ${HIP_VERSION}") link_libraries(hip::device) -add_compile_definitions(__HIP_PLATFORM_HCC__=1) +if(CK_hip_VERSION VERSION_GREATER_EQUAL 6.0.23494) + add_compile_definitions(__HIP_PLATFORM_AMD__=1) +else() + add_compile_definitions(__HIP_PLATFORM_HCC__=1) +endif() ## tidy include(EnableCompilerWarnings) diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake index 01b348c458..cf77991a64 100644 --- a/cmake/ClangTidy.cmake +++ b/cmake/ClangTidy.cmake @@ -149,7 +149,7 @@ function(clang_tidy_check TARGET) add_custom_target(${tidy_target} # for some targets clang-tidy not able to get information from .clang-tidy DEPENDS ${SOURCE} - COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" + COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_PLATFORM_AMD__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..." ) From a167e3c74420d9b0dc7d7f44415700e55b04412f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 19 Dec 2023 07:17:27 -0800 Subject: [PATCH 31/75] Bump rocm-docs-core from 0.30.1 to 0.30.2 in /docs/sphinx (#1104) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.30.1 to 0.30.2. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.30.1...v0.30.2) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 0a65ffc81a..2d49920398 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.30.1 +rocm-docs-core==0.30.2 sphinxcontrib-bibtex==2.6.1 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 75863c214e..7de078ba57 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.30.1 +rocm-docs-core==0.30.2 # via -r requirements.in six==1.16.0 # via From b305a29e4b0053afcff4671f5ee3c84f0540af38 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 19 Dec 2023 23:45:38 +0800 Subject: [PATCH 32/75] Remove index tensor in avgpool (#1093) * Remove index tensor * fix syntax --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin --- client_example/19_pool/avg_pool3d_fwd.cpp | 63 +++++++++++------------ 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/client_example/19_pool/avg_pool3d_fwd.cpp b/client_example/19_pool/avg_pool3d_fwd.cpp index db8e0569d7..6739a41b2f 100644 --- a/client_example/19_pool/avg_pool3d_fwd.cpp +++ b/client_example/19_pool/avg_pool3d_fwd.cpp @@ -94,7 +94,6 @@ int main(int argc, char* argv[]) SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); - SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); using DeviceOp = ck::tensor_operation::device::DevicePoolFwdMakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - in_length, - window_spatial_lengths, - out_length, - in_tensor_stride, - out_tensor_stride, - out_tensor_stride, - window_strides, - window_dilations, - input_left_pads, - input_right_pads, - {2, 3, 4}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + nullptr, + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -184,21 +183,21 @@ int main(int argc, char* argv[]) std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - in_length, - window_spatial_lengths, - out_length, - in_tensor_stride, - out_tensor_stride, - out_tensor_stride, - window_strides, - window_dilations, - input_left_pads, - input_right_pads, - {2, 3, 4}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + nullptr, + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + 
input_left_pads, + input_right_pads, + {2, 3, 4}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); From fb5bd51b42e68decf2a5f17cf10dd2f97f890d11 Mon Sep 17 00:00:00 2001 From: Artur Wojcik Date: Wed, 20 Dec 2023 23:34:53 +0100 Subject: [PATCH 33/75] enable compilation of INSTANCES_ONLY for Windows (#1082) * enable compilation of INSTANCES_ONLY for Windows * suppress ROCMChecks warnings on GoogleTests * suppress -Wfloat-equal warning on GoogleTests --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- .gitignore | 9 +++ CMakeLists.txt | 34 +++++---- cmake/getopt.cmake | 28 ++++++++ cmake/googletest.cmake | 50 ------------- cmake/gtest.cmake | 71 +++++++++++++++++++ .../element/unary_element_wise_operation.hpp | 5 ++ ...elementwise_layernorm_welford_variance.hpp | 4 +- .../tensor_operation/gpu/warp/wmma_gemm.hpp | 2 +- .../gpu/CMakeLists.txt | 1 - .../gpu/softmax/CMakeLists.txt | 4 +- library/src/utility/CMakeLists.txt | 10 +-- profiler/src/CMakeLists.txt | 2 +- test/CMakeLists.txt | 11 ++- 13 files changed, 149 insertions(+), 82 deletions(-) create mode 100644 cmake/getopt.cmake delete mode 100644 cmake/googletest.cmake create mode 100644 cmake/gtest.cmake diff --git a/.gitignore b/.gitignore index 340f11cbd2..090594a8df 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,12 @@ _static/ _templates/ _toc.yml _doxygen/ + +# JetBrains IDE +.idea/ +cmake-build*/ +build*/ + +# Python virtualenv +.venv/ + diff --git a/CMakeLists.txt b/CMakeLists.txt index d78e887efb..240832998d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,22 +4,27 @@ if(POLICY CMP0140) cmake_policy(SET CMP0140 NEW) endif() +get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) + # This has to be initialized before the project() command appears # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE -if( NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE ) - set( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel." ) +if(_GENERATOR_IS_MULTI_CONFIG) + set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING + "Available build types (configurations) on multi-config generators") +else() + set(CMAKE_BUILD_TYPE Release CACHE STRING + "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.") endif() # Default installation path -if(WIN32) - set(CMAKE_INSTALL_PREFIX "/opt/rocm/x86_64-w64-mingw32" CACHE PATH "") -else() +if(NOT WIN32) set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") endif() set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version}) +project(composable_kernel VERSION ${version} LANGUAGES CXX) +include(CTest) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") @@ -73,15 +78,15 @@ if(INSTANCES_ONLY) set(CK_ENABLE_INSTANCES_ONLY "ON") endif() +include(getopt) + # CK config file to record supported datatypes, etc. 
-configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h") +configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h) # CK version file to record release version as well as git commit hash find_package(Git REQUIRED) execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE) -configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h") - -enable_testing() +configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h) set(ROCM_SYMLINK_LIBS OFF) find_package(ROCM REQUIRED PATHS /opt/rocm) @@ -97,7 +102,7 @@ include(TargetFlags) rocm_setup_version(VERSION ${version}) -list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}") message("GPU_TARGETS= ${GPU_TARGETS}") @@ -142,7 +147,7 @@ find_package(hip) # SWDEV-413293 and https://reviews.llvm.org/D155213 math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") message("hip_version_flat=${hip_VERSION_FLAT}") -if(${hip_VERSION_FLAT} GREATER 500723302) +if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302) message("Adding the fno-offload-uniform-block compiler flag") add_compile_options(-fno-offload-uniform-block) endif() @@ -195,7 +200,6 @@ find_package(Threads REQUIRED) link_libraries(Threads::Threads) ## C++ -enable_language(CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) @@ -466,7 +470,9 @@ if(NOT DEFINED INSTANCES_ONLY) PACKAGE_NAME examples ) add_subdirectory(example) - add_subdirectory(test) + if(BUILD_TESTING) + add_subdirectory(test) + endif() rocm_package_setup_component(profiler LIBRARY_NAME composablekernel diff --git a/cmake/getopt.cmake b/cmake/getopt.cmake new file mode 100644 index 0000000000..dd985ff472 --- /dev/null +++ b/cmake/getopt.cmake @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +add_library(getopt::getopt INTERFACE IMPORTED GLOBAL) + +if(WIN32) + include(FetchContent) + + FetchContent_Declare( + getopt + GIT_REPOSITORY https://github.com/apwojcik/getopt.git + GIT_TAG main + SYSTEM + ) + + set(__build_shared_libs ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") + + FetchContent_MakeAvailable(getopt) + + # Restore the old value of BUILD_SHARED_LIBS + set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) + + FetchContent_GetProperties(getopt) + + target_link_libraries(getopt::getopt INTERFACE wingetopt) + target_include_directories(getopt::getopt INTERFACE ${getopt_SOURCE_DIR}/src) +endif() \ No newline at end of file diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake deleted file mode 100644 index d6577ac33e..0000000000 --- a/cmake/googletest.cmake +++ /dev/null @@ -1,50 +0,0 @@ -include(FetchContent) - -set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") - -if(GOOGLETEST_DIR) - set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") -endif() - -message(STATUS "Fetching GoogleTest") - -list(APPEND GTEST_CMAKE_CXX_FLAGS - -Wno-undef - -Wno-reserved-identifier - -Wno-global-constructors - -Wno-missing-noreturn - -Wno-disabled-macro-expansion - -Wno-used-but-marked-unused - -Wno-switch-enum - -Wno-zero-as-null-pointer-constant - -Wno-unused-member-function - -Wno-comma - -Wno-old-style-cast - -Wno-deprecated - -Wno-unsafe-buffer-usage -) -message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") - -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG b85864c64758dec007208e56af933fc3f52044ee -) - -# Will be necessary for windows build -# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -FetchContent_GetProperties(googletest) -if(NOT googletest_POPULATED) - FetchContent_Populate(googletest) - add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) -endif() - -target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) - -set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake new file mode 100644 index 0000000000..dc840e4e80 --- /dev/null +++ b/cmake/gtest.cmake @@ -0,0 +1,71 @@ +include(FetchContent) + +set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") + +if(GOOGLETEST_DIR) + set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") +endif() + +FetchContent_Declare( + GTest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 + SYSTEM +) + +# Suppress ROCMChecks WARNING on GoogleTests +set(ROCM_DISABLE_CHECKS FALSE) +macro(rocm_check_toolchain_var var access value list_file) + if(NOT ROCM_DISABLE_CHECKS) + _rocm_check_toolchain_var("${var}" "${access}" "${value}" "${list_file}") + endif() +endmacro() + +if(WIN32) + set(gtest_force_shared_crt ON CACHE_INTERNAL "") +endif() + 
+set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(INSTALL_GTEST OFF CACHE INTERNAL "") + +# Store the current value of BUILD_SHARED_LIBS +set(__build_shared_libs ${BUILD_SHARED_LIBS}) +set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") + +set(ROCM_DISABLE_CHECKS TRUE) +FetchContent_MakeAvailable(GTest) +set(ROCM_DISABLE_CHECKS FALSE) + +# Restore the old value of BUILD_SHARED_LIBS +set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) + +set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(INSTALL_GTEST OFF CACHE INTERNAL "") + +set(GTEST_CXX_FLAGS + -Wno-undef + -Wno-reserved-identifier + -Wno-global-constructors + -Wno-missing-noreturn + -Wno-disabled-macro-expansion + -Wno-used-but-marked-unused + -Wno-switch-enum + -Wno-zero-as-null-pointer-constant + -Wno-unused-member-function + -Wno-comma + -Wno-old-style-cast + -Wno-deprecated + -Wno-unsafe-buffer-usage + -Wno-float-equal +) + +if(WIN32) + list(APPEND GTEST_CXX_FLAGS + -Wno-suggest-destructor-override + -Wno-suggest-override + -Wno-nonportable-system-include-path + -Wno-language-extension-token) +endif() + +target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS}) +target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index e9c85964c5..eed60caef4 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -174,6 +174,11 @@ struct PassThrough { y = x; } + template <> + __host__ __device__ void operator()(int4_t& y, const int& x) const + { + y = type_convert(x); + } #endif template <> diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index 3ea72b8534..072275c089 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk index_t num_k_block_tile_iteration, AccDataType epsilon, const InDataTypePointerTuple p_in_global_tuple, - XDataType* const __restrict__ p_x_lds, + XDataType* const __restrict__ p_x_lds_, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, @@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); auto x_lds_val_buf = make_dynamic_buffer( - p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); + p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); auto in_thread_buf_tuple = generate_tuple( [&](auto) { diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 979f3567e9..814b4167b8 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -328,7 +328,7 @@ struct WmmaSelector } #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> - static constexpr auto GetWmma() + static constexpr auto GetWmma() { return WmmaInstr::wmma_i32_16x16x16_iu4; } diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index ac01c1b416..0a12e1c49e 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -152,7 +152,6 @@ ENDFOREACH() if(CK_DEVICE_OTHER_INSTANCES) add_library(device_other_operations STATIC ${CK_DEVICE_OTHER_INSTANCES}) add_library(composablekernels::device_other_operations ALIAS device_other_operations) - target_compile_features(device_other_operations PUBLIC) set_target_properties(device_other_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(device_other_operations PUBLIC $ diff --git a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt index dbe3764115..6daaec738a 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt @@ -1,5 +1,4 @@ -set(DEVICE_SOFTMAX_INSTANCES) -list(APPEND DEVICE_SOFTMAX_INSTANCES +add_instance_library(device_softmax_instance device_softmax_f16_f16_instance_rank3_reduce1.cpp device_softmax_f16_f16_instance_rank3_reduce2.cpp device_softmax_f16_f16_instance_rank3_reduce3.cpp @@ -14,4 +13,3 @@ list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank4_reduce2.cpp device_softmax_f32_f32_instance_rank4_reduce3.cpp device_softmax_f32_f32_instance_rank4_reduce4.cpp) -add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES}) diff --git a/library/src/utility/CMakeLists.txt b/library/src/utility/CMakeLists.txt index 7f6a59eebe..296e6c993a 100644 --- a/library/src/utility/CMakeLists.txt +++ b/library/src/utility/CMakeLists.txt @@ -1,17 +1,19 @@ -## utility -set(UTILITY_SOURCE +add_library(utility STATIC device_memory.cpp host_tensor.cpp convolution_parameter.cpp ) -add_library(utility STATIC ${UTILITY_SOURCE}) add_library(composable_kernel::utility ALIAS utility) - +set_target_properties(utility PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_compile_options(utility PRIVATE ${CMAKE_COMPILER_WARNINGS}) target_include_directories(utility PUBLIC "$" "$" ) +if(WIN32) + target_compile_definitions(utility PUBLIC NOMINMAX) +endif() rocm_install( TARGETS utility diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 7674b3b4f0..5144785aa0 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -58,7 +58,7 @@ set(PROFILER_EXECUTABLE ckProfiler) add_executable(${PROFILER_EXECUTABLE} ${PROFILER_SOURCES}) target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 94c5f2750f..90140659f6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -3,7 +3,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/profiler/include ) -include(googletest) +include(gtest) add_custom_target(tests) @@ -50,6 +50,7 @@ function(add_test_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) + target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt) 
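# [Editor's note] getopt::getopt is the INTERFACE target defined in
# cmake/getopt.cmake above; it resolves to the fetched wingetopt on Windows and
# to nothing elsewhere, keeping this link line portable.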
add_test(NAME ${TEST_NAME} COMMAND $) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) @@ -58,9 +59,7 @@ function(add_test_executable TEST_NAME) endif() #message("add_test returns ${result}") set(result ${result} PARENT_SCOPE) -endfunction(add_test_executable TEST_NAME) - -include(GoogleTest) +endfunction() function(add_gtest_executable TEST_NAME) message("adding gtest ${TEST_NAME}") @@ -109,14 +108,14 @@ function(add_gtest_executable TEST_NAME) # suppress gtest warnings target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) - target_link_libraries(${TEST_NAME} PRIVATE gtest_main) + target_link_libraries(${TEST_NAME} PRIVATE gtest_main getopt::getopt) add_test(NAME ${TEST_NAME} COMMAND $) rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) set(result 0) endif() #message("add_gtest returns ${result}") set(result ${result} PARENT_SCOPE) -endfunction(add_gtest_executable TEST_NAME) +endfunction() add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) From 78eb3f0b46aafc52c6d19a07b9dc5bd19b8e7807 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:35:25 -0800 Subject: [PATCH 34/75] Bump rocm-docs-core from 0.30.2 to 0.30.3 in /docs/sphinx (#1107) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.30.2 to 0.30.3. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.30.2...v0.30.3) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 2d49920398..6bcd2c43de 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.30.2 +rocm-docs-core==0.30.3 sphinxcontrib-bibtex==2.6.1 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 7de078ba57..e705e35e13 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.30.2 +rocm-docs-core==0.30.3 # via -r requirements.in six==1.16.0 # via From 20b1ae7cedcab951ff9499d5cf812176cf71b7e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 23 Dec 2023 22:12:49 +0100 Subject: [PATCH 35/75] Fix results verify in test_tensor (#1109) --- test/wrapper/test_tensor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/wrapper/test_tensor.cpp b/test/wrapper/test_tensor.cpp index 92f8e2e1bd..74cf7f1316 100644 --- a/test/wrapper/test_tensor.cpp +++ b/test/wrapper/test_tensor.cpp @@ -127,7 +127,7 @@ __global__ void TestTensorReadWriteDevice(void* data, void* success) StaticInitTensor(tensor_vgpr); StaticInitTensor(tensor_sgpr); - *casted_success_ptr &= TestTensorCheck1d(tensor_global); + *casted_success_ptr = TestTensorCheck1d(tensor_global); *casted_success_ptr &= TestTensorCheck3d(tensor_global); *casted_success_ptr &= TestTensorCheck1d(tensor_lds); From a35e466c86cbb513d7900825d7dca4698541808b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 2 Jan 2024 11:36:45 +0100 Subject: [PATCH 36/75] Revert "[SWDEV-435347] disable instances failed with mainlien compiler (#1077)" (#1101) This reverts commit ff24b537cb5412b4720a8923bbc090de6d020a3b. 
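[Editor's note -- not part of this patch series] Two remarks for context. First, on the one-character test_tensor fix (#1109) above: the success flag lives in freshly allocated device memory, so the very first check must assign ('=') rather than and-assign ('&='), because and-ing folds an indeterminate initial value into every later result. A host-side sketch of the same hazard (all names illustrative):

#include <cstdlib>
#include <iostream>

int main()
{
    // freshly allocated memory holds an indeterminate value, like the device
    // buffer backing the success flag in test_tensor.cpp
    bool* success = static_cast<bool*>(std::malloc(sizeof(bool)));
    // buggy pattern:  *success &= true;  // result depends on garbage
    *success = true;  // fixed pattern: the first check assigns,
    *success &= true; // later checks may accumulate with &=
    std::cout << std::boolalpha << *success << std::endl;
    std::free(success);
    return 0;
}

Second, the revert below simply restores the bf16 scaleadd-AB instance registrations (header declaration, factory dispatch, and instance definition) that #1077 had commented out.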
--- ...rouped_convolution_forward_scaleadd_ab.hpp | 43 +++++------ ..._ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 73 +++++++++---------- 2 files changed, 56 insertions(+), 60 deletions(-) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp index 348bcaef8a..1bea403afa 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp @@ -23,20 +23,19 @@ using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; #ifdef CK_ENABLE_BF16 // grouped conv3d forward multi AB scaleadd, NDHWGC/GKZYXC/NDHWGK -// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347 -// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( -// std::vector, -// NDHWGK, -// ck::Tuple, -// ck::Tuple, -// ck::Tuple<>, -// BF16, -// ScaleAdd, -// ScaleAdd, -// PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -152,15 +151,13 @@ struct DeviceOperationInstanceFactory> && - // is_same_v> && - // is_same_v && is_same_v) - // { - // add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - // op_ptrs); - // } + if constexpr(is_same_v> && + is_same_v> && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v> && diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index d5b9da86c1..c7801f02ce 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -9,43 +9,42 @@ namespace tensor_operation { namespace device { namespace instance { -// TODO: Workaround for https://ontrack-internal.amd.com/browse/SWDEV-435347 -// void add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( -// std::vector, -// NDHWGK, -// ck::Tuple, -// ck::Tuple, -// ck::Tuple<>, -// BF16, -// ScaleAdd, -// ScaleAdd, -// PassThrough>>>& instances) -// { -// add_device_operation_instances( -// instances, -// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, -// NDHWGC, -// GKZYXC, -// NDHWGK, -// ConvFwdDefault>{}); -// add_device_operation_instances( -// instances, -// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, -// NDHWGC, -// GKZYXC, -// NDHWGK, -// ConvFwd1x1P0>{}); -// add_device_operation_instances( -// instances, -// device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, -// NDHWGC, -// GKZYXC, -// NDHWGK, -// ConvFwd1x1S1P0>{}); -// } +void 
add_device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + ScaleAdd, + ScaleAdd, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scaleadd_ab_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvFwd1x1S1P0>{}); +} } // namespace instance } // namespace device From 0e07dfdeabfc0e7922f56b1ea1ed910a38633434 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 2 Jan 2024 14:00:53 -0800 Subject: [PATCH 37/75] change the googletest cmake syntax for older cmake versions (#1116) --- cmake/gtest.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index dc840e4e80..0915f53411 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -10,7 +10,6 @@ FetchContent_Declare( GTest GIT_REPOSITORY https://github.com/google/googletest.git GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 - SYSTEM ) # Suppress ROCMChecks WARNING on GoogleTests From b268f273de0d60d8c385e1e88d3b1cfaa18f3d85 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 2 Jan 2024 14:01:12 -0800 Subject: [PATCH 38/75] adding -Wno-switch-default compiler flag (#1115) --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 240832998d..a65c90e15d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,6 +67,7 @@ endif() #for f8/bf8_t type add_compile_options(-Wno-bit-int-extension) add_compile_options(-Wno-pass-failed) +add_compile_options(-Wno-switch-default) if(DL_KERNELS) add_definitions(-DDL_KERNELS) From 4234b3a6910283a4fec04cc6cd5541c05b01c2fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 3 Jan 2024 01:10:57 +0100 Subject: [PATCH 39/75] Add tensor partition and generic copy for ck wrapper (#1108) * Add tensor partition and generic copy for ck wrapper * Update changelog * Stylistic fixes * Change shape/strides logic to descriptor transforms * Fixes * Fix client example * Fix comments --- CHANGELOG.md | 2 +- docs/wrapper.rst | 8 + include/ck/utility/tuple_helper.hpp | 11 + include/ck/wrapper/layout.hpp | 120 ++------ include/ck/wrapper/operations/copy.hpp | 41 +++ include/ck/wrapper/tensor.hpp | 222 +++++++------- include/ck/wrapper/utils/layout_utils.hpp | 157 +++++++--- include/ck/wrapper/utils/tensor_partition.hpp | 285 ++++++++++++++++++ include/ck/wrapper/utils/tensor_utils.hpp | 110 +++---- test/wrapper/CMakeLists.txt | 4 + test/wrapper/test_copy.cpp | 129 ++++++++ test/wrapper/test_layout.cpp | 11 +- test/wrapper/test_partition.cpp | 119 ++++++++ test/wrapper/test_tensor.cpp | 27 +- 14 files changed, 940 insertions(+), 306 deletions(-) create mode 100644 include/ck/wrapper/operations/copy.hpp create mode 100644 include/ck/wrapper/utils/tensor_partition.hpp create mode 100644 test/wrapper/test_copy.cpp create mode 100644 test/wrapper/test_partition.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 2891b8585b..abca69142e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ None - Support for NHWGC (2D and 3D) grouped convolution backward 
weight (#769 #804) - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) - Support for Batched Gemm DL (#732) -- Introduce wrapper sublibrary (limited functionality). (#1071, #1098) +- Introduce wrapper sublibrary (limited functionality). (#1071, #1098, #1108) ### Changes - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) diff --git a/docs/wrapper.rst b/docs/wrapper.rst index a2f60b97ae..da3a79eda8 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -71,3 +71,11 @@ Tensor helpers ------------------------------------- .. doxygenfile:: tensor_utils.hpp + +.. doxygenfile:: tensor_partition.hpp + +------------------------------------- +Operations +------------------------------------- + +.. doxygenfile:: copy.hpp diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 75f2693f20..f365230054 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -178,4 +178,15 @@ __host__ __device__ constexpr auto TupleDepth(const Tuple&) return math::max(TupleDepth(Ts{})...); } +template +__host__ __device__ constexpr auto TupleSlice(const Tuple& tuple) +{ + return generate_tuple( + [&](auto i) { + using Idx = Number; + return tuple.At(Idx{}); + }, + Number{}); +} + } // namespace ck diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index f20d985b49..1643eb7383 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -14,11 +14,9 @@ namespace wrapper { * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * (dynamic layout). It is possible to pass nested shapes * (e.g. ((4, 2), 2)), nested dimensions are merged. - * \tparam Strides Tuple of Number<> (for compile-time layout) or index_t - * (dynamic layout). Stride tuple should be nested if shape tuple is - * nested. + * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims. 
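+ *
+ * A minimal construction sketch (editorial addition; make_layout from
+ * layout_utils.hpp builds the unnested descriptor internally):
+ * \code
+ * // packed column-major strides are deduced when none are passed
+ * const auto layout = ck::wrapper::make_layout(make_tuple(Number<4>{}, Number<8>{}));
+ * \endcode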
*/ -template +template struct Layout { private: @@ -31,7 +29,7 @@ struct Layout { return generate_tuple( [&](auto) { - if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) + if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime()) { // runtime layout return index_t(0); @@ -45,27 +43,6 @@ struct Layout Number::Size()>{}); } - // Generate packed (column-major) strides if not passed - template - __host__ __device__ constexpr static auto - GenerateColumnMajorPackedStrides(const Tuple& shape) - { - const auto unrolled_shape = UnrollNestedTuple(shape); - return generate_tuple( - [&](auto i) { - if constexpr(i.value == 0) - { - return I1; - } - else - { - return TupleReduce([](auto x, auto y) { return x * y; }, - unrolled_shape); - } - }, - Number{}); - } - // Generate LowerDims in Compile-time for MergeTrasform using passed Type // If element of Tuple is also tuple, then merge (generate sequence for merge) // If tuple is element, then pass through (sequence with one element) @@ -207,33 +184,15 @@ struct Layout return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); } - template - __host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape, - const LayoutStrides& strides) - { - const auto unrolled_shape = UnrollNestedTuple(shape); - const auto unrolled_strides = UnrollNestedTuple(strides); - static_assert(unrolled_shape.Size() == unrolled_strides.Size(), - "Size of strides and shape are not consistent."); - return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); - } - - // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`. - using DeducedStrides = - std::conditional_t>, - remove_cvref_t, - Strides>; - using FlattenDescriptorType = - remove_cvref_t; using Descriptor1dType = - remove_cvref_t; + remove_cvref_t; using DefaultIdxsTupleType = remove_cvref_t; template __host__ __device__ constexpr static auto TransformDesc(const Tuple& shape, const Tuple& idx, - const FlattenDescriptorType& naive_descriptor) + const UnnestedDescriptorType& naive_descriptor) { if constexpr(Tuple::Size() == I1) { @@ -256,48 +215,33 @@ struct Layout } using MergedNestsDescriptorType = remove_cvref_t; + Shape{}, DefaultIdxsTupleType{}, UnnestedDescriptorType{}))>; public: __host__ __device__ constexpr auto GetElementSpaceSize() const { - return flatten_descriptor_.GetElementSpaceSize(); + return unnested_descriptor_.GetElementSpaceSize(); } __host__ __device__ Layout() = delete; + /** * \brief Layout constructor. * * \param shape Shape for layout. - * \param strides Strides for layout (optional if tensor is packed). + * \param unnested_descriptor Descriptor */ - __host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides) - : flatten_descriptor_{}, shape_(shape), strides_(strides) + __host__ __device__ constexpr Layout(const Shape& shape, + const UnnestedDescriptorType& unnested_descriptor) + : shape_(shape) { // Construct if runtime mode - if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) + if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime()) { - flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_); - descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_); + unnested_descriptor_ = unnested_descriptor; + descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_); merged_nests_descriptor_ = - TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_); - } - } - - /** - * \brief Layout constructor (with default packed column-major strides). 
- * - * \param shape Shape for layout. - */ - __host__ __device__ constexpr Layout(const Shape& shape) - : flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_)) - { - if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime()) - { - flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_); - descriptor_1d_ = MakeMerge1d(shape_, flatten_descriptor_); - merged_nests_descriptor_ = - TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_); + TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_); } } @@ -310,9 +254,9 @@ struct Layout template __host__ __device__ constexpr index_t operator()() const { - static_assert(FlattenDescriptorType::IsKnownAtCompileTime(), + static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(), "Compiletime operator used on runtime layout."); - using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{})); + using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{})); using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); return TransformedDesc{}.CalculateOffset(UnrolledIdx{}); } @@ -339,7 +283,7 @@ struct Layout else { // Custom index, need to transform descriptor - const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_); + const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_); return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); } } @@ -351,7 +295,7 @@ struct Layout * \return Calculated size. */ template - __host__ __device__ constexpr index_t GetLength() const + __host__ __device__ constexpr auto GetLength() const { const auto elem = shape_.At(Number{}); if constexpr(is_detected>::value) @@ -371,7 +315,7 @@ struct Layout * * \return Calculated size. */ - __host__ __device__ constexpr index_t GetLengths() const + __host__ __device__ constexpr auto GetLengths() const { const auto unrolled_shape = UnrollNestedTuple(shape_); return TupleReduce([](auto x, auto y) { return x * y; }, @@ -385,13 +329,6 @@ struct Layout */ __host__ __device__ constexpr const Shape& GetShape() const { return shape_; } - /** - * \brief Strides getter. - * - * \return Strides. - */ - __host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; } - /** * \brief Get default lengths (tuple filled with Shape length elements). * @@ -417,17 +354,26 @@ struct Layout * * \return Default descriptor. */ - __host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor() + __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const { return merged_nests_descriptor_; } + /** + * \brief Get unnested descriptor (with unrolled dims) + * + * \return Flatten descriptor. + */ + __host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const + { + return unnested_descriptor_; + } + private: - FlattenDescriptorType flatten_descriptor_; + UnnestedDescriptorType unnested_descriptor_; Descriptor1dType descriptor_1d_; MergedNestsDescriptorType merged_nests_descriptor_; const Shape shape_; - const DeducedStrides strides_; }; } // namespace wrapper diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp new file mode 100644 index 0000000000..aec80f9ca7 --- /dev/null +++ b/include/ck/wrapper/operations/copy.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
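+//
+// Editorial summary: this header adds a generic element-wise copy between two
+// wrapper tensors of equal size. When either side is a statically sized
+// register tensor the loop is unrolled at compile time via static_for;
+// otherwise a plain runtime loop is used.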
+ +#pragma once + +#include "../utils/tensor_utils.hpp" + +namespace ck { +namespace wrapper { + +/** + * \brief Perform generic copy between two tensors. Tensors must have the + * same size. + * + * \param src_tensor Source tensor. + * \param dst_tensor Destination tensor. + */ +template +__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) +{ + if constexpr(!SrcTensorType::IsDynamicBuffer) + { + using SizeType = decltype(size(src_tensor)); + static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); + } + else if constexpr(!DstTensorType::IsDynamicBuffer) + { + using SizeType = decltype(size(dst_tensor)); + static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); + } + else + { + for(int i = 0; i < size(src_tensor); i++) + { + dst_tensor(i) = src_tensor(i); + } + } +} + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 4ec6498fbc..a363641373 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "utils/tensor_utils.hpp" +#include "utils/tensor_partition.hpp" #include "utils/layout_utils.hpp" namespace ck { @@ -15,14 +16,14 @@ namespace wrapper { * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR). * \tparam ElementType Element data type. * \tparam Shape Tensor shape (layout component). - * \tparam Strides Tensor strides (layout component). + * \tparam UnnestedDescriptorType Unnested descriptor (layout component). * \tparam NumVectors Number of vectors (only for VGPR, SGPR). * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR). 
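+ *
+ * Construction sketch (editorial addition; `ptr` and `layout` are
+ * placeholders for a raw pointer and a wrapper layout):
+ * \code
+ * auto tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(ptr, layout);
+ * \endcode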
*/ template @@ -31,50 +32,20 @@ struct Tensor private: // Check if Tuple contains Slice object template - constexpr static bool IsSlicing(T&&) + __host__ __device__ constexpr static bool IsSlicing(T&&) { return is_detected::value; } template - constexpr static bool IsSlicing(Tuple&&) + __host__ __device__ constexpr static bool IsSlicing(Tuple&&) { return (IsSlicing(Ts{}) || ...); } - // Calculate first index of new tensor after slice - // It is needed to calculate offset for new tensor - template - constexpr auto GetStartIdxForSlicedTensor(const Tuple& idx) const - { - const auto start_idx_for_sliced_tensor = generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - return GetStartIdxForSlicedTensor(idx.At(num_i)); - } - else if constexpr(is_detected>>::value) - { - // if slice, return the beginning of the interval - return idx.At(num_i).from_; - } - else - { - // if one dim selected - return idx.At(num_i); - } - }, - Number::Size()>{}); - - return start_idx_for_sliced_tensor; - } - // Calculate new tensor shape after slice template - constexpr auto GetShapeFromSlicedTensor(const Tuple& idx, - const ShapeTmpType& shape) const + __host__ __device__ constexpr auto GetShapeFromSlicedTensor(const Tuple& idx, + const ShapeTmpType& shape) const { // Pack each value in tuple to remove empty tuples after generation auto new_shape = generate_tuple( @@ -112,67 +83,137 @@ struct Tensor return UnrollNestedTuple<0, 1>(new_shape); } - template - constexpr auto GetStridesFromSlicedTensor(const Tuple& idx, - const StridesTmpType& strides) const + // Generate Freeze for each of nested shape + template + __host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, + const ShapeTmpType& shape) const + { + const auto unrolled_shape = UnrollNestedTuple(shape); + return generate_tuple( + [&](auto i) { + // dimension offset from idx + const auto dim = unrolled_shape.At(Number{}); + const auto dim_idx = idx % dim; + idx /= dim; + return make_freeze_transform(dim_idx); + }, + Number{}); + } + + template + __host__ __device__ constexpr auto + GetTransformsFromSlicedTensor(const Tuple& idx, const ShapeTmpType& shape) const { // Pack each value in tuple to remove empty tuples after generation - auto new_strides = generate_tuple( + auto transforms = generate_tuple( [&](auto i) { constexpr auto num_i = Number{}; if constexpr(is_detected>>::value) { - if constexpr(!IsSlicing(tuple_element_t>{})) - { - // if tuple does not have any slice then we can remove dimension - return Tuple<>{}; - } - else - { - // if tuple then recurrence - return make_tuple( - GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i))); - } + return GetTransformsFromSlicedTensor(idx.At(num_i), shape.At(num_i)); } else if constexpr(is_detected>>::value) { - // Stride will be the same - return make_tuple(strides.At(num_i)); + + const auto from = idx.At(num_i).from_; + const auto dim = shape.At(num_i); + const auto range = idx.At(num_i).range(dim); + return make_slice_transform(range, from, from + range); } else { // remove dimension for just value - return Tuple<>{}; + return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i)); } }, Number::Size()>{}); // Remove empty tuples (deleted elements) and return - return UnrollNestedTuple<0, 1>(new_strides); + return UnrollNestedTuple(transforms); + } + + // There is no output for Freeze transform + template + __host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze&) const + { + return Sequence<>{}; + } + 
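+    // Editorial note: the Slice overload below reports the single upper
+    // dimension it produces, while Freeze reports none, so GenerateUpperDims
+    // numbers only the dimensions that survive slicing.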
+ template + __host__ __device__ constexpr auto + GetSequenceVal(const ck::Slice&) const + { + return Sequence{}; + } + + template + __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) const + { + return Tuple<>{}; + } + + template + __host__ __device__ constexpr auto + GenerateUpperDims(const Tuple& transforms) const + { + constexpr auto num_transforms = Tuple::Size(); + // Deduce Sequence element for specific transform + const auto currect_elem = GetSequenceVal(transforms.At(Number<0>{})); + if constexpr(is_same_v>) + { + const auto next_tuple = GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); + return concat_tuple(make_tuple(currect_elem), next_tuple); + } + else + { + // Increase i if current_elem is Slice transform + const auto next_tuple = + GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); + return concat_tuple(make_tuple(currect_elem), next_tuple); + } + } + + template + __host__ __device__ constexpr auto + GetDescriptorFromSlicedTensor(const Tuple& idx, + const ShapeTmpType& shape, + const FlattenDescriptor& flatten_desc) const + { + constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); + + const auto transforms = GetTransformsFromSlicedTensor(idx, shape); + using TransformsTupleType = decltype(transforms); + + const auto lower_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; + return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); } public: - using ElementSpaceSize = decltype(Layout{ - Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer - using TensorElementType = ElementType; // DataType + using ElementSpaceSize = decltype(Layout{ + Shape{}, UnnestedDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer + using TensorElementType = ElementType; // DataType static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace; static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr || BufferAddressSpace == MemoryTypeEnum ::Vgpr); __host__ __device__ Tensor() = delete; - __host__ __device__ Tensor(ElementType* pointer, const Layout& layout) + __host__ __device__ Tensor(ElementType* pointer, + const Layout& layout) : layout_(layout), buffer_(make_dynamic_buffer(pointer, layout.GetElementSpaceSize())) { } - __host__ __device__ Tensor(const Layout& layout) : layout_(layout) + __host__ __device__ Tensor(const Layout& layout) + : layout_(layout) { static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register."); } - __host__ __device__ constexpr const Layout& GetLayout() const + __host__ __device__ constexpr const Layout& GetLayout() const { return layout_; } @@ -182,21 +223,14 @@ struct Tensor __host__ __device__ auto operator[](const Tuple& idx) const { static_assert(IsDynamicBuffer, "Register slice is not supported"); - // Calculate offset based on first idx for new tensor - const index_t offset = layout_(GetStartIdxForSlicedTensor(idx)); + const auto& shape = layout_.GetShape(); + auto new_shape = GetShapeFromSlicedTensor(idx, shape); - auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape()); - if constexpr(is_same_v>) - { - auto new_layout = make_layout(new_shape); - return make_tensor(buffer_.p_data_ + offset, new_layout); - } - else - { - auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides()); - auto new_layout = make_layout(new_shape, new_strides); - return 
make_tensor(buffer_.p_data_ + offset, new_layout); - } + const auto& flatten_desc = layout_.GetUnnestedDescriptor(); + auto new_desc = GetDescriptorFromSlicedTensor(idx, shape, flatten_desc); + const auto new_layout = + Layout(new_shape, new_desc); + return make_tensor(buffer_.p_data_, new_layout); } template {}), bool> = false> @@ -222,18 +256,10 @@ struct Tensor } else { - if constexpr(is_same_v>) - { - constexpr index_t offset = - Layout{Shape{}}.template operator()>(); - return buffer_[Number{}]; - } - else - { - constexpr index_t offset = - Layout{Shape{}, Strides{}}.template operator()>(); - return buffer_[Number{}]; - } + constexpr index_t offset = Layout{ + Shape{}, + UnnestedDescriptorType{}}.template operator()>(); + return buffer_[Number{}]; } } @@ -260,18 +286,10 @@ struct Tensor } else { - if constexpr(is_same_v>) - { - constexpr index_t offset = - Layout{Shape{}}.template operator()>(); - return buffer_(Number{}); - } - else - { - constexpr index_t offset = - Layout{Shape{}, Strides{}}.template operator()>(); - return buffer_(Number{}); - } + constexpr index_t offset = Layout{ + Shape{}, + UnnestedDescriptorType{}}.template operator()>(); + return buffer_(Number{}); } } @@ -292,6 +310,8 @@ struct Tensor return layout_.GetDefaultDescriptor(); } + __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; } + private: using DynamicBufferType = DynamicBuffer; - const Layout layout_; + const Layout layout_; Buffer buffer_; }; diff --git a/include/ck/wrapper/utils/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp index 5df9dd7dea..f4ba0a969f 100644 --- a/include/ck/wrapper/utils/layout_utils.hpp +++ b/include/ck/wrapper/utils/layout_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,11 +22,57 @@ namespace wrapper { // Disable from doxygen docs generation /// @cond // forward declaration -template +template struct Layout; template using is_tuple = decltype(std::declval().IsTuple()); + +namespace { +// Generate packed (column-major) strides if not passed +template +__host__ __device__ constexpr static auto +GenerateColumnMajorPackedStrides(const Tuple& shape) +{ + const auto unrolled_shape = UnrollNestedTuple(shape); + return generate_tuple( + [&](auto i) { + if constexpr(i.value == 0) + { + return Number<1>{}; + } + else + { + return TupleReduce{}.value, i.value>([](auto x, auto y) { return x * y; }, + unrolled_shape); + } + }, + Number{}); +} + +template +__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape, + const LayoutStrides& strides) +{ + const auto unrolled_shape = UnrollNestedTuple(shape); + if constexpr(is_same_v>) + { + // if not passed, then generate + const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape); + static_assert(unrolled_shape.Size() == unrolled_strides.Size(), + "Size of strides and shape are not consistent."); + return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); + } + else + { + const auto unrolled_strides = UnrollNestedTuple(strides); + static_assert(unrolled_shape.Size() == unrolled_strides.Size(), + "Size of strides and shape are not consistent."); + return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); + } +} +} // namespace + /// @endcond // make_* @@ -38,10 +84,10 @@ using is_tuple = decltype(std::declval().IsTuple()); * \return Constructed layout. 
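+ *
+ * Example (editorial sketch): a 4x8 layout written with explicit packed
+ * column-major strides:
+ * \code
+ * const auto layout = ck::wrapper::make_layout(make_tuple(4, 8), make_tuple(1, 4));
+ * \endcode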
*/ template -__host__ __device__ constexpr Layout make_layout(const Shape& shape, - const Strides& strides) +__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides) { - return Layout(shape, strides); + using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{})); + return Layout(shape, MakeFlattenDescriptor(shape, strides)); } /** @@ -52,9 +98,10 @@ __host__ __device__ constexpr Layout make_layout(const Shape& sh * \return Constructed layout. */ template -__host__ __device__ constexpr Layout> make_layout(const Shape& shape) +__host__ __device__ constexpr auto make_layout(const Shape& shape) { - return Layout>(shape); + using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{})); + return Layout(shape, MakeFlattenDescriptor(shape, Tuple<>{})); } // Layout helpers @@ -89,26 +136,51 @@ __host__ __device__ constexpr auto get(const Tuple& tuple) * \param layout Layout to create sub layout. * \return Requsted sub layout. */ -template -__host__ __device__ constexpr auto get(const Layout& layout) +template +__host__ __device__ constexpr auto get(const Layout& layout) { - const auto& shape = layout.GetShape(); - const auto& new_shape = get(shape); + const auto& shape = layout.GetShape(); + const auto new_shape = get(shape); static_assert(is_detected::value, "Shape of sub layout must be tuple"); - if constexpr(is_same_v>) - { - // If stride not passed, create without strides - return make_layout(new_shape); - } - else - { - const auto& strides = layout.GetStrides(); - const auto& new_strides = get(strides); - static_assert(is_detected::value, - "Strides of sub layout must be tuple"); - return make_layout(new_shape, new_strides); - } + + constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); + constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size(); + constexpr auto shape_offset = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size(); + + const auto unrolled_shape = UnrollNestedTuple(shape); + const auto transforms = generate_tuple( + [&](auto i) { + // Compare Idx with shape + if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims) + { + // Remove dimension + return make_freeze_transform(Number<0>{}); + } + else + { + return make_pass_through_transform(unrolled_shape.At(i)); + } + }, + Number{}); + + const auto lower_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + const auto upper_dims = generate_tuple( + [&](auto i) { + if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims) + return Sequence<>{}; + + else + { + return Sequence{}; + } + }, + Number{}); + + const auto& flatten_desc = layout.GetUnnestedDescriptor(); + auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); + return Layout(new_shape, new_desc); } /** @@ -142,8 +214,8 @@ __host__ __device__ T constexpr size(const T& dim) * \param layout Layout to get Shape of. * \return Requsted length. */ -template -__host__ __device__ constexpr index_t size(const Layout& layout) +template +__host__ __device__ constexpr auto size(const Layout& layout) { return layout.template GetLength(); } @@ -155,7 +227,7 @@ __host__ __device__ constexpr index_t size(const Layout& layout) * \return Requsted size. 
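+ *
+ * For example (editorial note), size(make_tuple(make_tuple(2, 2), 4)) unrolls
+ * the nested shape and returns 2 * 2 * 4 = 16.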
*/ template -__host__ __device__ constexpr index_t size(const Tuple& shape) +__host__ __device__ constexpr auto size(const Tuple& shape) { const auto unrolled_shape = UnrollNestedTuple(shape); return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; }, @@ -168,8 +240,8 @@ __host__ __device__ constexpr index_t size(const Tuple& shape) * \param layout Layout to calculate shape size. * \return Requsted size. */ -template -__host__ __device__ constexpr index_t size(const Layout& layout) +template +__host__ __device__ constexpr auto size(const Layout& layout) { return layout.GetLengths(); } @@ -182,7 +254,7 @@ __host__ __device__ constexpr index_t size(const Layout& layout) * \return Requsted length. */ template -__host__ __device__ constexpr index_t size(const Tuple& tuple) +__host__ __device__ constexpr auto size(const Tuple& tuple) { return size(tuple.At(Number{})); } @@ -208,8 +280,9 @@ __host__ __device__ constexpr auto size(const T& elem) * \param layout Layout to calculate rank. * \return Requsted rank. */ -template -__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout& layout) +template +__host__ __device__ constexpr auto +rank([[maybe_unused]] const Layout& layout) { return Shape::Size(); } @@ -261,8 +334,8 @@ __host__ __device__ constexpr auto rank(const T& elem) * \param layout Layout to calculate depth. * \return Requsted depth. */ -template -__host__ __device__ constexpr auto depth(const Layout& layout) +template +__host__ __device__ constexpr auto depth(const Layout& layout) { const auto& shape = layout.GetShape(); return TupleDepth(shape); @@ -307,26 +380,14 @@ __host__ __device__ constexpr auto depth(const T& elem) return depth(get(elem)); } -/** - * \brief Get Layout strides. - * - * \param layout Layout to get strides from. - * \return Requsted strides. - */ -template -__host__ __device__ constexpr const auto& stride(const Layout& layout) -{ - return layout.GetStrides(); -} - /** * \brief Get Layout shape. * * \param layout Layout to get shape from. * \return Requsted shape. */ -template -__host__ __device__ constexpr const auto& shape(const Layout& layout) +template +__host__ __device__ constexpr const auto& shape(const LayoutType& layout) { return layout.GetShape(); } diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp new file mode 100644 index 0000000000..a0634f6b38 --- /dev/null +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -0,0 +1,285 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
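+//
+// Editorial summary: helpers that carve a wrapper tensor into per-thread
+// partitions (make_local_partition) and per-block tiles (make_local_tile).
+// The optional steps tuple selects between raked (step == 1) and packed
+// distributions of threads or blocks over the data.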
+ +#pragma once + +#include "tensor_utils.hpp" +#include "layout_utils.hpp" + +namespace ck { +namespace wrapper { + +namespace { +// Calculate shape for partition based on number of threads per each dim and +// previous shape +template +__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple& shape, + const Tuple& thread_lengths) +{ + static_assert(Tuple::Size() == Tuple::Size(), "Wrong thread_lengths shape."); + return generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + // if tuple then recurrence + return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i)); + } + else + { + const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i); + return slice_len; + } + }, + Number::Size()>{}); +} + +// Calculate shape for partition based on number of threads per each dim, +// previous strides and steps +template +__host__ __device__ constexpr auto +CalculateLocalPartitionDescriptor(const Tuple& shape, + const Tuple& thread_lengths, + const Tuple& steps, + const FlattenDescType& flatten_desc) +{ + + static_assert(Tuple::Size() == Tuple::Size(), "Wrong thread_lengths shape."); + const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths); + const auto unrolled_shape = UnrollNestedTuple(shape); + constexpr auto dims = decltype(unrolled_thread_lengths)::Size(); + + using UnrolledStepsType = decltype(UnrollNestedTuple(steps)); + + using I1 = Number<1>; + + const auto transforms = generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_same_v, Tuple<>>) + { + // By default raked partition + const auto partition_stride = unrolled_thread_lengths.At(num_i); + return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), + make_tuple(partition_stride)); + } + else if constexpr(!is_same_v, index_t>) + { + // Compiletime partition + if constexpr(is_same_v, I1>) + { + // raked + const auto partition_stride = unrolled_thread_lengths.At(num_i); + return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), + make_tuple(partition_stride)); + } + else + { + // packed + return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), + make_tuple(I1{})); + } + } + else + { + // Runtime partition + if(steps.At(num_i) == 1) + { + // raked + const auto partition_stride = unrolled_thread_lengths.At(num_i); + return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), + make_tuple(partition_stride)); + } + else + { + // packed + return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), + make_tuple(I1{})); + } + } + }, + Number{}); + + const auto lower_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + const auto upper_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); +} + +template +__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple& thread_lengths, + const Tuple& steps, + index_t& thread_id) +{ + return generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + // if tuple then recurrence + if constexpr(is_same_v, Tuple<>>) + { + return CalculateLayoutOffsetIdxImpl( + thread_lengths.At(num_i), Tuple<>{}, thread_id); + } + else + { + return CalculateLayoutOffsetIdxImpl( + thread_lengths.At(num_i), steps.At(num_i), thread_id); + } + } + else + { + // Update thread_id after each dim + const auto dim_thread_id = thread_id % thread_lengths.At(num_i); + 
thread_id /= thread_lengths.At(num_i); + if constexpr(is_same_v, Tuple<>>) + { + return dim_thread_id; + } + else + { + // Apply step + return steps.At(num_i) * dim_thread_id; + } + } + }, + Number::Size()>{}); +} + +// Convert integer thread_idx to tuple index with steps applied +template +__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple& thread_lengths, + const Tuple& steps, + const index_t thread_id) +{ + // Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates + index_t thread_id_copy = thread_id; + return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy); +} + +// Apply steps to index represented as tuple +template +__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple& steps, + const Tuple& block_idxs) +{ + return generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + // if tuple then recurrence + if constexpr(is_same_v, Tuple<>>) + { + return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i)); + } + else + { + return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i)); + } + } + else + { + if constexpr(is_same_v, Tuple<>>) + { + return block_idxs.At(num_i); + } + else + { + // apply step + return steps.At(num_i) * block_idxs.At(num_i); + } + } + }, + Number::Size()>{}); +} + +// User passes only shape per block to the make_local_tile function. This function calculates +// block layout based on the shape. +template +__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple& shape, + const Tuple& tile_shape) +{ + return generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + // if tuple then recurrence + return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i)); + } + else + { + return shape.At(num_i) / tile_shape.At(num_i); + } + }, + Number::Size()>{}); +} +} // namespace + +/** + * \brief Create local partition for thread. + * + * \param tensor Tensor for partition. + * \param thread_lengths Layout of threads. + * \param thread_id Thread index represented as integer. + * \param steps Thread step (default=1, raked partition) + * \return Partition tensor. + */ +template > +__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor, + const ThreadLengthsTuple& thread_lengths, + const index_t thread_id, + const StepsTuple steps = StepsTuple{}) +{ + // Create shape, strides and layout for new partition tensor + const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths); + // Create new descriptor and layout + const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor(); + auto partition_desc = + CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc); + const auto partition_layout = Layout( + partition_shape, partition_desc); + // Calculate offset for new partition tensor + const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id); + const auto partition_offset = layout(tensor)(offset_idx); + return make_tensor(tensor.GetPointer() + partition_offset, + partition_layout); +} + +/** + * \brief Create local tile for thread block. + * + * \param tensor Tensor for partition. + * \param tile_shape Shapes of requested tile. + * \param block_idx Block index represented as tuple. + * \param steps Block step (default=1, raked partition) + * \return Tile tensor. 
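+ *
+ * Usage sketch (editorial addition, mirroring test_copy.cpp; device-side
+ * code, so blockIdx.x is available):
+ * \code
+ * const auto block_idxs = make_tuple(make_tuple(0, 0), blockIdx.x);
+ * const auto tile = make_local_tile(tensor, tile_shape, block_idxs, block_steps);
+ * \endcode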
+ */ +template > +__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, + const BlockShapeTuple& tile_shape, + const BlockIdxTuple& block_idx, + const StepsTuple steps = StepsTuple{}) +{ + // Create block lengths, strides and layout for new tile tensor + const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape); + // Create new descriptor and layout + const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor(); + auto tile_desc = + CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc); + const auto tile_layout = Layout, decltype(tile_desc)>( + tile_shape, tile_desc); + // Calculate offset for new partition tensor + const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx); + const auto tile_offset = layout(tensor)(offset_idx); + return make_tensor(tensor.GetPointer() + tile_offset, + tile_layout); +} + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp index 5f0dc3e500..1e932e62e1 100644 --- a/include/ck/wrapper/utils/tensor_utils.hpp +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -27,12 +27,12 @@ using MemoryTypeEnum = AddressSpaceEnum; // Disable from doxygen docs generation /// @cond // forward declarations -template +template struct Layout; template @@ -98,11 +98,19 @@ using is_tuple = decltype(std::declval().IsTuple()); * \param layout Tensor layout. * \return Constructed tensor. */ -template -constexpr auto make_tensor(ElementType* pointer, const Layout& layout) +template +constexpr auto make_tensor(ElementType* pointer, + const Layout& layout) { - return Tensor( - pointer, layout); + return Tensor(pointer, layout); } /** @@ -112,19 +120,21 @@ constexpr auto make_tensor(ElementType* pointer, const Layout& l * \tparam NumVectors Number of vectors. * \tparam ScalarPerVector Scalars per vector. * \tparam ElementType Memory data type. - * \param layout Tensor layout. * \return Constructed tensor. */ template -constexpr auto make_register_tensor(const Layout& layout) + typename ElementType> +constexpr auto make_register_tensor() { - static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported"); - return Tensor(layout); + const auto layout = make_layout(make_tuple(Number{}), make_tuple(Number<1>{})); + return Tensor>, + std::remove_const_t>, + NumVectors, + ScalarPerVector>(layout); } /** @@ -136,12 +146,15 @@ constexpr auto make_register_tensor(const Layout& layout) template -__host__ __device__ constexpr const auto& -layout(const Tensor& - tensor) +__host__ __device__ constexpr const auto& layout(const Tensor& tensor) { return tensor.GetLayout(); } @@ -157,12 +170,15 @@ template -__host__ __device__ constexpr index_t -size(const Tensor& - tensor) +__host__ __device__ constexpr auto size(const Tensor& tensor) { return size(tensor.GetLayout()); } @@ -178,12 +194,15 @@ template -__host__ __device__ constexpr index_t -rank(const Tensor& - tensor) +__host__ __device__ constexpr auto rank(const Tensor& tensor) { return rank(tensor.GetLayout()); } @@ -199,35 +218,19 @@ template -__host__ __device__ constexpr index_t -depth(const Tensor& - tensor) +__host__ __device__ constexpr auto depth(const Tensor& tensor) { return depth(tensor.GetLayout()); } -/** - * \brief Get Tensor strides. - * - * \param tensor Tensor to get strides from. - * \return Requsted strides. 
- */ -template -__host__ __device__ constexpr const auto& -stride(const Tensor& - tensor) -{ - return stride(tensor.GetLayout()); -} - /** * \brief Get Tensor shape. * @@ -237,12 +240,15 @@ stride(const Tensor -__host__ __device__ constexpr const auto& -shape(const Tensor& - tensor) +__host__ __device__ constexpr const auto& shape(const Tensor& tensor) { return shape(tensor.GetLayout()); } diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 6b25c08a8a..6c3e29ab87 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -2,3 +2,7 @@ add_gtest_executable(test_layout test_layout.cpp) target_link_libraries(test_layout PRIVATE utility) add_gtest_executable(test_tensor test_tensor.cpp) target_link_libraries(test_tensor PRIVATE utility) +add_gtest_executable(test_copy test_copy.cpp) +target_link_libraries(test_copy PRIVATE utility) +add_gtest_executable(test_partition test_partition.cpp) +target_link_libraries(test_partition PRIVATE utility) diff --git a/test/wrapper/test_copy.cpp b/test/wrapper/test_copy.cpp new file mode 100644 index 0000000000..5cf09a54be --- /dev/null +++ b/test/wrapper/test_copy.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" + +// Test copy from Global to Global through LDS and VGPR +template +__global__ void TestCopyDevice(const InputTensor input_tensor, + OutputTensor output_tensor, + const BlockShape tile_shape, + const ThreadLayoutShape thread_layout, + const LocalTileSteps block_steps, + const LocalPartitionSteps thread_steps) +{ + __shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)]; + auto tensor_lds = ck::wrapper::make_tensor( + p_shared, ck::wrapper::make_layout(tile_shape)); + + const auto block_idxs = ck::make_tuple(ck::make_tuple(0, 0), blockIdx.x); + + // Get local tiles for global memory + const auto input_local_tile = + ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs, block_steps); + const auto output_local_tile = + ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs, block_steps); + + // Get partition per thread + const auto input_local_partition = ck::wrapper::make_local_partition( + input_local_tile, thread_layout, threadIdx.x, thread_steps); + auto lds_local_partition = + ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x, thread_steps); + auto output_local_partition = ck::wrapper::make_local_partition( + output_local_tile, thread_layout, threadIdx.x, thread_steps); + + // Allocate VGPR + constexpr ck::index_t scalar_per_vector = 1; + constexpr ck::index_t vgpr_size = ck::wrapper::size(lds_local_partition); + auto tensor_vgpr = ck::wrapper::make_register_tensor(); + + // Perform copy + ck::wrapper::copy(input_local_partition, lds_local_partition); + ck::wrapper::copy(lds_local_partition, tensor_vgpr); + ck::wrapper::copy(tensor_vgpr, output_local_partition); +} + +void PerformCopyGlobalToGlobalViaLDS() +{ + const auto shape = + ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<256>{}); + const auto strides = + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<2>{}), 
ck::Number<4>{}); + const auto layout = ck::wrapper::make_layout(shape, strides); + + // 0, 1, 2, ..., size(shape) - 1 + std::vector input_data(ck::wrapper::size(shape)); + std::iota(input_data.begin(), input_data.end(), 0); + + // Global memory buffers + DeviceMem in_buf(ck::wrapper::size(layout) * sizeof(ck::index_t)); + DeviceMem out_buf(ck::wrapper::size(layout) * sizeof(ck::index_t)); + + in_buf.ToDevice(input_data.data()); + out_buf.SetZero(); + + // Create tensors for global memory + const auto input_tensor_global = ck::wrapper::make_tensor( + static_cast(in_buf.GetDeviceBuffer()), layout); + auto output_tensor_global = ck::wrapper::make_tensor( + static_cast(out_buf.GetDeviceBuffer()), layout); + + const auto thread_layout = + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<32>{}); + const auto tile_shape = + ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<64>{}); + + const auto thread_steps = + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<2>{}); + const auto block_steps = + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<64>{}); + + const ck::index_t grid_size = ck::math::integer_divide_ceil( + ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape)); + + const auto kernel = TestCopyDevice; + launch_and_time_kernel(StreamConfig{}, + kernel, + dim3(grid_size), + dim3(ck::wrapper::size(thread_layout)), + 0, + input_tensor_global, + output_tensor_global, + tile_shape, + thread_layout, + block_steps, + thread_steps); + + // Verify results + std::vector output_data(ck::wrapper::size(shape)); + out_buf.FromDevice(output_data.data()); + EXPECT_TRUE(ck::utils::check_err(output_data, input_data)); +} + +TEST(TestCopy, CopyGlobalToGlobalViaLDS) { PerformCopyGlobalToGlobalViaLDS(); } diff --git a/test/wrapper/test_layout.cpp b/test/wrapper/test_layout.cpp index 14a8b96462..a128a6d84f 100644 --- a/test/wrapper/test_layout.cpp +++ b/test/wrapper/test_layout.cpp @@ -84,7 +84,8 @@ TEST_F(TestWrapperLayout, 2d) ck::make_tuple(ck::Sequence<0>{})); const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(d1, d0)); const auto layout_compiletime = - ck::wrapper::make_layout(ck::make_tuple(ck::Number{}, ck::Number{})); + ck::wrapper::make_layout(ck::make_tuple(ck::Number{}, ck::Number{}), + ck::make_tuple(ck::Number{}, ck::Number{})); std::vector> idxs; for(ck::index_t h = 0; h < d1; h++) @@ -435,19 +436,11 @@ TEST(TestLayoutHelpers, ShapeAndStrides) constexpr bool check_compiletime_shape = std::is_same_v>; - constexpr bool check_compiletime_strides = - std::is_same_v>; constexpr bool check_runtime_shape = std::is_same_v>; - constexpr bool check_runtime_strides = - std::is_same_v>; EXPECT_TRUE(check_compiletime_shape); - EXPECT_TRUE(check_compiletime_strides); EXPECT_TRUE(check_runtime_shape); - EXPECT_TRUE(check_runtime_strides); } TEST(TestLayoutHelpers, Hierarchical) diff --git a/test/wrapper/test_partition.cpp b/test/wrapper/test_partition.cpp new file mode 100644 index 0000000000..df56b879f6 --- /dev/null +++ b/test/wrapper/test_partition.cpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
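+//
+// Editorial summary: these tests enumerate every thread (or block) index,
+// build the corresponding raked and packed partitions on the host, and check
+// each partition's size and first element against the layout's offsets.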
+ +#include +#include +#include +#include +#include +#include + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" + +TEST(TestPartition, LocalPartition) +{ + const auto shape = + ck::make_tuple(ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}), ck::Number<4>{}); + const auto strides = + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}), ck::Number<64>{}); + const auto layout = ck::wrapper::make_layout(shape, strides); + + std::vector data(ck::wrapper::size(layout)); + std::iota(data.begin(), data.end(), 0); + + const auto tensor = + ck::wrapper::make_tensor(data.data(), layout); + + const auto thread_steps = + ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<1>{}), ck::Number<1>{}); + const auto thread_layout = + ck::make_tuple(ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}), ck::Number<1>{}); + + for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) + { + const auto raked_partition = + ck::wrapper::make_local_partition(tensor, thread_layout, thread_id); + + const auto expected_partition_size = + ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); + EXPECT_EQ(ck::wrapper::size(raked_partition), expected_partition_size); + EXPECT_EQ(raked_partition(0), thread_id); + } + + for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) + { + const auto packed_partition = + ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_steps); + + const auto expected_partition_size = + ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); + const auto expected_partition_first_val = thread_id * ck::wrapper::size<0, 0>(thread_steps); + EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size); + EXPECT_EQ(packed_partition(0), expected_partition_first_val); + } +} + +TEST(TestPartition, LocalTile) +{ + const auto shape = + ck::make_tuple(ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}), ck::Number<4>{}); + const auto strides = + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}), ck::Number<64>{}); + const auto layout = ck::wrapper::make_layout(shape, strides); + + std::vector data(ck::wrapper::size(layout)); + std::iota(data.begin(), data.end(), 0); + + const auto tensor = + ck::wrapper::make_tensor(data.data(), layout); + + const auto block_steps = + ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); + const auto block_shape = + ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); + const auto block_layout = + ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); + + std::vector, ck::index_t>> block_idxs; + for(ck::index_t x = 0; x < ck::wrapper::size<0, 0>(block_layout); x++) + { + for(ck::index_t y = 0; y < ck::wrapper::size<0, 1>(block_layout); y++) + { + for(ck::index_t z = 0; z < ck::wrapper::size<1>(block_layout); z++) + { + block_idxs.emplace_back(ck::make_tuple(x, y), z); + } + } + } + + for(const auto& block_idx : block_idxs) + { + const auto raked_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx); + + const auto expected_tile_size = ck::wrapper::size(block_shape); + EXPECT_EQ(ck::wrapper::size(raked_tile), expected_tile_size); + EXPECT_EQ(raked_tile(0), layout(block_idx)); + } + + for(const auto& block_idx : block_idxs) + { + const auto 
packed_tile = + ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_steps); + + const auto expected_tile_size = ck::wrapper::size(block_shape); + const auto expected_tile_first_val = + ck::wrapper::size<0, 0>(block_idx) * ck::wrapper::size<0, 0>(block_shape) * + ck::wrapper::size<0, 0>(strides) + + ck::wrapper::size<0, 1>(block_idx) * ck::wrapper::size<0, 1>(block_shape) * + ck::wrapper::size<0, 1>(strides) + + ck::wrapper::size<1>(block_idx) * ck::wrapper::size<1>(block_shape) * + ck::wrapper::size<1>(strides); + EXPECT_EQ(ck::wrapper::size(packed_tile), expected_tile_size); + EXPECT_EQ(packed_tile(0), expected_tile_first_val); + } +} diff --git a/test/wrapper/test_tensor.cpp b/test/wrapper/test_tensor.cpp index 74cf7f1316..2d4d6f2750 100644 --- a/test/wrapper/test_tensor.cpp +++ b/test/wrapper/test_tensor.cpp @@ -108,7 +108,6 @@ __global__ void TestTensorReadWriteDevice(void* data, void* success) bool* casted_success_ptr = static_cast(success); const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2)); - constexpr auto register_layout = ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{})); auto tensor_global = ck::wrapper::make_tensor(casted_data_ptr, layout); @@ -116,11 +115,11 @@ __global__ void TestTensorReadWriteDevice(void* data, void* success) auto tensor_vgpr = ck::wrapper::make_register_tensor(register_layout); + ck::index_t>(); auto tensor_sgpr = ck::wrapper::make_register_tensor(register_layout); + ck::index_t>(); InitTensor(tensor_global); InitTensor(tensor_lds); @@ -151,7 +150,7 @@ TEST(TestTensor, ReadWriteGlobalLdsRegistersMemory) TestTensorReadWriteDevice, dim3(1), dim3(1), - nelems * sizeof(ck::index_t), + 0, data_buf.GetDeviceBuffer(), success_buf.GetDeviceBuffer()); @@ -173,33 +172,45 @@ TEST(TestTensor, Slicing) auto tensor2x2x2 = tensor(ck::make_tuple(ck::wrapper::slice(2), ck::wrapper::slice(2)), ck::wrapper::slice(2)); + EXPECT_EQ(tensor2x2x2(0), layout(ck::make_tuple(ck::make_tuple(0, 0), 0))); EXPECT_EQ(ck::wrapper::rank(tensor2x2x2), 2); EXPECT_EQ(ck::wrapper::depth(tensor2x2x2), 2); EXPECT_EQ(ck::wrapper::size(tensor2x2x2), 8); EXPECT_TRUE(TestTensorCheck1d(tensor2x2x2)); auto tensor2x2 = tensor(ck::make_tuple(1, ck::wrapper::slice(2)), ck::wrapper::slice(2)); + EXPECT_EQ(tensor2x2(0), layout(ck::make_tuple(ck::make_tuple(1, 0), 0))); EXPECT_EQ(ck::wrapper::rank(tensor2x2), 2); EXPECT_EQ(ck::wrapper::depth(tensor2x2), 2); EXPECT_EQ(ck::wrapper::size(tensor2x2), 4); - EXPECT_TRUE(TestTensorCheck1d(tensor2x2, layout(ck::make_tuple(ck::make_tuple(1, 0), 0)))); + EXPECT_TRUE(TestTensorCheck1d(tensor2x2)); auto tensor1x1 = tensor(ck::make_tuple(1, ck::wrapper::slice(1, 2)), ck::wrapper::slice(1, 2)); + EXPECT_EQ(tensor1x1(0), layout(ck::make_tuple(ck::make_tuple(1, 1), 1))); EXPECT_EQ(rank(tensor1x1), 2); EXPECT_EQ(depth(tensor1x1), 2); EXPECT_EQ(size(tensor1x1), 1); - EXPECT_TRUE(TestTensorCheck1d(tensor1x1, layout(ck::make_tuple(ck::make_tuple(1, 1), 1)))); + EXPECT_TRUE(TestTensorCheck1d(tensor1x1)); auto tensor2 = tensor(ck::make_tuple(1, 1), ck::wrapper::slice(0, 2)); + EXPECT_EQ(tensor2(0), layout(ck::make_tuple(ck::make_tuple(1, 1), 0))); EXPECT_EQ(ck::wrapper::rank(tensor2), 1); EXPECT_EQ(ck::wrapper::depth(tensor2), 1); EXPECT_EQ(ck::wrapper::size(tensor2), 2); - EXPECT_TRUE(TestTensorCheck1d(tensor2, layout(ck::make_tuple(ck::make_tuple(1, 1), 0)))); + EXPECT_TRUE(TestTensorCheck1d(tensor2)); + + auto tensor2_v2 = tensor(2, ck::wrapper::slice(0, 2)); + EXPECT_EQ(tensor2_v2(0), layout(ck::make_tuple(2, 0))); + 
EXPECT_EQ(ck::wrapper::rank(tensor2_v2), 1); + EXPECT_EQ(ck::wrapper::depth(tensor2_v2), 1); + EXPECT_EQ(ck::wrapper::size(tensor2_v2), 2); + EXPECT_TRUE(TestTensorCheck1d(tensor2_v2)); // negative indexing auto tensor1x2 = tensor(ck::make_tuple(1, ck::wrapper::slice(0, -2)), ck::wrapper::slice()); + EXPECT_EQ(tensor1x2(0), layout(ck::make_tuple(ck::make_tuple(1, 0), 0))); EXPECT_EQ(rank(tensor1x2), 2); EXPECT_EQ(depth(tensor1x2), 2); EXPECT_EQ(size(tensor1x2), 2); - EXPECT_TRUE(TestTensorCheck1d(tensor1x2, layout(ck::make_tuple(ck::make_tuple(1, 0), 0)))); + EXPECT_TRUE(TestTensorCheck1d(tensor1x2)); } From fbf31a2ea3f68f8741d07bbc45a16eb674b35108 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 3 Jan 2024 07:56:44 -0800 Subject: [PATCH 40/75] fix the cmake option syntax (#1117) --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a65c90e15d..bdeba33eac 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,8 +180,8 @@ elseif(CK_PARALLEL_COMPILE_JOBS) endif() -option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) -option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) +option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) +option(USE_OPT_NAVI3X "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) From aa3e2d7967ec9d1316a0a015b4935f6c6fc8b21b Mon Sep 17 00:00:00 2001 From: arai713 <67439843+arai713@users.noreply.github.com> Date: Thu, 4 Jan 2024 08:33:19 -0800 Subject: [PATCH 41/75] Transpose profiler fix (#1114) * added working example for 5D input using 1D kernel * example with 5D input tensor and 2d kernel - not working: issues with arguments * added updated version of 3d device op - changed descriptors/dims * added example file to check kernel * fixed descriptor and isSupportedArgument stride problem * added and modified kernel for 3d - updated tids/loop * adding some more 5d example files * fixed some issues * changes made for testing * working version: fixed error in stride for A, still a bit inefficient * cleaned up formatting/comments * updating formatting * more formatting fixes * fixing cmake, adding back gpu targets in cmake script * adding client example * added instances for client example * fixed errors in client example * implemented client ex with device_elementwise.hpp and device_elementwise_3d_impl.hpp * removed extra files * minor formatting and naming fixes * adding test files and profiler * fixing minor error * minor fix * removed unneccesary comments, renamed files * updated instance list for client example, added different layout example * removing instances * fixed error in instance generation * remove comments * update profiler and client example tensor layouts * fixed errors in test/profiler * updated vector dim access to enable vector load * updated test/profiler files * updated example with 1d kernel * updating profiler * renamed files * disabled device op for MI300 * skip elementwise_permute_2d on gfx94x * Update CMakeLists.txt * fixing CMake - disabling some GPU targets * added transpose profiler to CMake * fixed transpose profiler errors * fixed instances for tests/profiler * cleaned up code in transpose profiler source code * added some comments, updated 
copyright * made function arguments const where possible --------- Co-authored-by: Jing Zhang Co-authored-by: Jing Zhang Co-authored-by: zjing14 --- .../elementwise_permute_3d.cpp | 12 +- .../transpose/device_transpose_instance.hpp | 13 +- .../profiler/profile_transpose_impl.hpp | 12 +- profiler/src/CMakeLists.txt | 2 + profiler/src/profile_transpose.cpp | 112 ++++++++++++++++++ test/transpose/test_transpose.cpp | 34 ++++-- test/transpose/test_transpose_ut_cases.inc | 28 ----- test/transpose/test_transpose_util.hpp | 54 --------- 8 files changed, 151 insertions(+), 116 deletions(-) create mode 100644 profiler/src/profile_transpose.cpp delete mode 100644 test/transpose/test_transpose_ut_cases.inc delete mode 100644 test/transpose/test_transpose_util.hpp diff --git a/example/44_elementwise_permute/elementwise_permute_3d.cpp b/example/44_elementwise_permute/elementwise_permute_3d.cpp index 669785a545..b061c0da34 100644 --- a/example/44_elementwise_permute/elementwise_permute_3d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_3d.cpp @@ -14,8 +14,8 @@ using F16 = ck::half_t; using F32 = float; -using ADataType = F16; -using BDataType = F16; +using ADataType = F32; +using BDataType = F32; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceElementwisePermuteInstance = @@ -25,10 +25,10 @@ using DeviceElementwisePermuteInstance = 2, // NumDim_m, {N, C} 2, // NumDim_n, {H, W} 1, // NumDim_k, {D} - 8, // MPerThread - 8, // NPerThread - 8, // KPerThread - ck::Sequence<8>, // InScalarPerVectorSeq + 4, // MPerThread + 4, // NPerThread + 4, // KPerThread + ck::Sequence<4>, // InScalarPerVectorSeq ck::Sequence<4>>; // OutScalarPerVectorSeq template diff --git a/library/include/ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp index 817e717a89..6ac0871a80 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/transpose/device_transpose_instance.hpp @@ -21,20 +21,19 @@ template using S = ck::Sequence; using device_transpose_f16_instances = std::tuple< - // FOR 16, 32, 16, 32, 16 // clang-format off - DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 8, 1, 1, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 8, 4, 4, ck::Sequence<1>, ck::Sequence<1>> + DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 8, 8, 8, ck::Sequence<8>, ck::Sequence<4>>, + DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 4, 4, 8, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<1>, ck::Sequence<1>> // clang-format on >; using device_transpose_f32_instances = std::tuple< - // for 16, 8, 16, 32, 8 -> test with instances for fp16 // clang-format off DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 4, 8, 4, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 4, 8, 8, ck::Sequence<1>, ck::Sequence<1>> + DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 
4, 4, 4, ck::Sequence<4>, ck::Sequence<1>>, + DeviceElementwise3dImpl, ck::Tuple, PassThrough, 2, 2, 1, 4, 4, 4, ck::Sequence<4>, ck::Sequence<4>> // clang-format on >; diff --git a/profiler/include/profiler/profile_transpose_impl.hpp b/profiler/include/profiler/profile_transpose_impl.hpp index 3dae9ef48b..a4f2cb6763 100644 --- a/profiler/include/profiler/profile_transpose_impl.hpp +++ b/profiler/include/profiler/profile_transpose_impl.hpp @@ -25,7 +25,7 @@ namespace ck { namespace profiler { template -void host_elementwise4D(HostTensorB& B_nchwd, const HostTensorA& A_ncdhw, Functor functor) +void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) { for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) @@ -34,7 +34,7 @@ void host_elementwise4D(HostTensorB& B_nchwd, const HostTensorA& A_ncdhw, Functo for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) { auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_nchwd(n, c, h, w, d), a_val); + functor(B_ndhwc(n, d, h, w, c), a_val); } } @@ -77,8 +77,6 @@ bool profile_transpose_impl(int do_verification, using ElementOp = ck::tensor_operation::element_wise::PassThrough; - // const auto element_op = ElementOp{}; - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); @@ -118,6 +116,7 @@ bool profile_transpose_impl(int do_verification, // re-init C to zero before profiling next kernel b_device_buf.SetZero(); + // run for verification invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); if(do_verification) @@ -136,6 +135,7 @@ bool profile_transpose_impl(int do_verification, std::string op_name = op_ptr->GetTypeString(); + // run for timing purposes float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); @@ -153,10 +153,6 @@ bool profile_transpose_impl(int do_verification, std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; - // pass = pass & ck::utils::check_err(b_device_result, b_host_result); - pass &= ck::utils::check_err( - b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); - if(tflops > best_tflops) { best_op_name = op_name; diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 5144785aa0..68ef04ed11 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -29,6 +29,7 @@ set(PROFILER_SOURCES profile_batchnorm_infer.cpp profile_grouped_conv_bwd_data.cpp profile_conv_tensor_rearrange.cpp + profile_transpose.cpp ) if(DL_KERNELS) @@ -91,6 +92,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_d target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) diff --git a/profiler/src/profile_transpose.cpp b/profiler/src/profile_transpose.cpp new file mode 100644 index 0000000000..d04c9fa2c4 --- /dev/null +++ b/profiler/src/profile_transpose.cpp @@ -0,0 +1,112 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_transpose_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct DataType +{ + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16, // 1 +}; + +#define OP_NAME "transpose" +#define OP_DESC "Transpose" + +struct TransposeArgParser +{ + std::unordered_map> long_opts = {{"lengths", {}}}; + + bool parse_opt(const int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + const int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +static void print_helper_msg() +{ + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: verification (0: no; 1: yes)\n"); + printf("arg4: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg5: print tensor value (0: no; 1: yes)\n"); + printf("arg6: time kernel (0=no, 1=yes)\n"); + printf("arg7: --lengths: N, C, D, H, W\n"); +} + +int profile_transpose(int argc, char* argv[]) +{ + if(argc != 7) + { + print_helper_msg(); + exit(1); + } + TransposeArgParser arg_parser; + + const auto data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + arg_parser(argc, argv); + const std::vector lengths = arg_parser.long_opts["lengths"]; + + using F32 = float; + using F16 = ck::half_t; + + auto profile = [&](auto a_type, auto b_type) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + constexpr ck::index_t NumDim = 5; + + bool pass = ck::profiler::profile_transpose_impl( + do_verification, init_method, do_log, time_kernel, lengths); + + return pass ? 0 : 1; + }; + + if(data_type == DataType::F32_F32_F32_F32_F32) + { + return profile(F32{}, F32{}); + } + else if(data_type == DataType::F16_F16_F16_F16_F16) + { + return profile(F16{}, F16{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_transpose); diff --git a/test/transpose/test_transpose.cpp b/test/transpose/test_transpose.cpp index 74991c62da..ead622b4d7 100644 --- a/test/transpose/test_transpose.cpp +++ b/test/transpose/test_transpose.cpp @@ -1,27 +1,35 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
- -#include - #include "gtest/gtest.h" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "test_transpose_util.hpp" +#include "profiler/profile_transpose_impl.hpp" using F16 = ck::half_t; using F32 = float; +using ck::index_t; template class TestTranspose : public ::testing::Test { + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + + void Run() + { + std::vector> lengths = { + {4, 16, 16, 32, 5}, {8, 16, 16, 32, 8} /**{32, 16, 16, 32, 8},**/}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_transpose_impl( + true, 2, false, false, length); + EXPECT_TRUE(success); + } + } }; -// clang-format off -using KernelTypes = ::testing::Types< - std::tuple< F16, F16>, - std::tuple< F32, F32> - >; -// clang-format on +using KernelTypes = ::testing::Types, std::tuple>; TYPED_TEST_SUITE(TestTranspose, KernelTypes); - -//#include "test_transpose_ut_cases.inc" +TYPED_TEST(TestTranspose, Test_FP16) { this->Run(); } +TYPED_TEST(TestTranspose, Test_FP32) { this->Run(); } diff --git a/test/transpose/test_transpose_ut_cases.inc b/test/transpose/test_transpose_ut_cases.inc deleted file mode 100644 index 59a2a6c72c..0000000000 --- a/test/transpose/test_transpose_ut_cases.inc +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -TYPED_TEST(TestTranspose, Test1) -{ - // for 16, 8, 16, 32, 8 - std::vector Ms{1, 2, 3, 4, 5, 6}; - std::vector lengths{16, 8, 16, 32, 8}; - /**constexpr int N = 16; - constexpr int C = 8; - constexpr int D = 16; - constexpr int H = 32; - constexpr int W = 8;**/ - - this->Run(); -} - -TYPED_TEST(TestTranpose, Test2) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - std::vector lengths{16, 8, 16, 32, 16}; - /**constexpr int N = 16; - constexpr int C = 8; - constexpr int D = 16; - constexpr int H = 32; - constexpr int W = 8;**/ - - this->Run(); -} diff --git a/test/transpose/test_transpose_util.hpp b/test/transpose/test_transpose_util.hpp deleted file mode 100644 index 4bc25a6032..0000000000 --- a/test/transpose/test_transpose_util.hpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "include/ck/utility/data_type.hpp" -#include "profiler/profile_transpose_impl.hpp" - -namespace ck { -namespace test { - -template -class TestTranspose : public testing::Test -{ - using F32 = float; - - protected: - using ADataType = std::tuple_element_t<0, Tuple>; - using BDataType = std::tuple_element_t<1, Tuple>; - - public: - static constexpr bool verify_ = true; - static constexpr int init_method_ = 1; // decimal value initialization - static constexpr bool log_ = false; - static constexpr bool bench_ = false; // measure kernel performance - std::vector> lengths_ = {{16, 32, 16, 32, 16}, {16, 8, 16, 32, 8}}; - - void Run() - { - for(auto length : this->lengths_) - { - this->RunSingle(length); - } - } - - void RunSingle() - { - bool pass = ck::profiler::profile_transpose_impl( - verify_, init_method_, log_, bench_, lengths_); - EXPECT_TRUE(pass); - } -}; - -} // namespace test -} // namespace ck From 11e27522618996406d2423fddbb6bf55fedbc770 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Thu, 4 Jan 2024 17:38:24 +0100 Subject: [PATCH 42/75] Add missing copyrights in elementwise_permute examples (#1118) --- example/44_elementwise_permute/elementwise_permute.cpp | 3 +++ example/44_elementwise_permute/elementwise_permute_3d.cpp | 3 +++ example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp | 3 +++ .../44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp | 3 +++ .../44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp | 3 +++ .../44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp | 3 +++ .../44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp | 3 +++ .../44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp | 3 +++ 8 files changed, 24 insertions(+) diff --git a/example/44_elementwise_permute/elementwise_permute.cpp b/example/44_elementwise_permute/elementwise_permute.cpp index b40c5e3411..24e161c6d3 100644 --- a/example/44_elementwise_permute/elementwise_permute.cpp +++ b/example/44_elementwise_permute/elementwise_permute.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include diff --git a/example/44_elementwise_permute/elementwise_permute_3d.cpp b/example/44_elementwise_permute/elementwise_permute_3d.cpp index b061c0da34..f3aca57c35 100644 --- a/example/44_elementwise_permute/elementwise_permute_3d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_3d.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 3b5a255410..8e9bc64ab6 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ #include #include diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp index 5d11ddfaea..30231a3758 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index f496d26a8a..9d5fdc0cc7 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index dd7883cd21..7d215cef24 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index 619f481357..69e411c59a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index b1f0e12f49..69f40fe165 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include From d89700201b2cba848f531442b2f51bbaa1d65da7 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 5 Jan 2024 08:01:33 -0800 Subject: [PATCH 43/75] Add a docker for testing CK with rocm6.0.1 RC1. 
(#1119) * add docker for rocm6.0.1 rc1 * modify the path to clang for test compilers in CI * fix the hipcc/clang path for test compilers in CI * fix the dockerfile for older rocm versions --- Dockerfile | 22 ++++++++++++++-------- Jenkinsfile | 10 +++++----- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/Dockerfile b/Dockerfile index b9339ec5d4..c058c85723 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,12 +16,18 @@ RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN wget https://repo.radeon.com/amdgpu-install/6.0/ubuntu/focal/amdgpu-install_6.0.60000-1_all.deb --no-check-certificate -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.0.60000-1_all.deb - -RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list' +RUN if [ "$ROCMVERSION" != "6.0.1" ]; then \ + sh -c "wget https://repo.radeon.com/amdgpu-install/6.0/ubuntu/focal/amdgpu-install_6.0.60000-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.0.60000-1_all.deb && \ + wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ + elif [ "$ROCMVERSION" = "6.0.1" ] && [ "$compiler_version" = "rc1" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.0-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.0-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.0.1 rel-95 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1704947; \ + fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" RUN amdgpu-install -y --usecase=rocm --no-dkms @@ -111,7 +117,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if [ "$compiler_version" != "" ] && [ "$compiler_commit" = "" ]; then \ +RUN if [ [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ] && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" 
-DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -119,7 +125,7 @@ RUN if [ "$compiler_version" != "" ] && [ "$compiler_commit" = "" ]; then \ else echo "using the release compiler"; \ fi -RUN if [ "$compiler_version" != "" ] && [ "$compiler_commit" != "" ]; then \ +RUN if [ [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ] && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ diff --git a/Jenkinsfile b/Jenkinsfile index 2bb48b85ce..268cc7606f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -33,7 +33,7 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.ROCMVERSION != "6.1"){ + if (params.ROCMVERSION != "6.0.1"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -84,7 +84,7 @@ def build_compiler(){ compiler = '/opt/rocm/bin/hipcc' } else{ - if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ compiler = "/llvm-project/build/bin/clang++" } else{ @@ -293,7 +293,7 @@ def buildHipClangJob(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -348,7 +348,7 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -479,7 +479,7 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION != "" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } From 61545bda3568b38fbf08f218a6e4091da83fc32e Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Fri, 5 Jan 2024 18:36:02 +0100 Subject: [PATCH 44/75] Update the recommended version of ROCm 
in docs (#1110) --- docs/dockerhub.rst | 6 +++--- docs/tutorial_hello_world.rst | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst index 66ec91096e..cf420030ff 100644 --- a/docs/dockerhub.rst +++ b/docs/dockerhub.rst @@ -30,7 +30,7 @@ run a docker container:: --group-add sudo \ -w /root/workspace \ -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.6 \ + rocm/composable_kernel:ck_ub20.04_rocm6.0 \ /bin/bash and build the CK:: @@ -76,11 +76,11 @@ The docker images have everything you need for running CK including: Which image is right for me? ------------------------------------- -Let's take a look at the image naming, for example ``ck_ub20.04_rocm5.6``. The image specs are: +Let's take a look at the image naming, for example ``ck_ub20.04_rocm6.0``. The image specs are: * ``ck`` - made for running Composable Kernel; * ``ub20.04`` - based on Ubuntu 20.04; -* ``rocm5.6`` - ROCm platform version 5.6. +* ``rocm6.0`` - ROCm platform version 6.0. So just pick the right image for your project dependencies and you're all set. diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst index bfb197e085..6b8154d462 100644 --- a/docs/tutorial_hello_world.rst +++ b/docs/tutorial_hello_world.rst @@ -72,8 +72,8 @@ First let's clone the library and rebase to the tested version:: To make our lives easier we prepared `docker images `_ with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use -``rocm/composable_kernel:ck_ub20.04_rocm5.6`` image, it is based on Ubuntu 20.04 and -ROCm v5.6. +``rocm/composable_kernel:ck_ub20.04_rocm6.0`` image, it is based on Ubuntu 20.04 and +ROCm v6.0. If your current folder is ``${HOME}``, start the docker container with:: @@ -83,7 +83,7 @@ If your current folder is ``${HOME}``, start the docker container with:: --group-add sudo \ -w /root/workspace \ -v ${HOME}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.6 \ + rocm/composable_kernel:ck_ub20.04_rocm6.0 \ /bin/bash If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace`` From a39163814e437a1fe4a1c7025956d3fb4810afa8 Mon Sep 17 00:00:00 2001 From: randyh62 <42045079+randyh62@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:04:01 -0800 Subject: [PATCH 45/75] doc reorg and edits (#1112) * doc reorg and edits * Update wrapper.rst with changes from PR #1098 * Update docs/dockerhub.rst Co-authored-by: Bartlomiej Wroblewski * Update docs/index.rst Co-authored-by: Bartlomiej Wroblewski * Update docs/what-is-ck.rst Co-authored-by: Bartlomiej Wroblewski * Update docs/what-is-ck.rst Restored to 4 bullets, with additional text for wrapper. 
Co-authored-by: Bartlomiej Wroblewski * Update docs/Contributors_Guide.rst Co-authored-by: Lisa * Update API_Reference_Guide.rst using sentence case for title * updated index structure per Lisa * separate docker hub and tutorial --------- Co-authored-by: Bartlomiej Wroblewski Co-authored-by: Lisa Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- docs/API_Reference_Guide.rst | 16 +-- docs/Contributors_Guide.rst | 33 +++--- docs/Supported_Primitives_Guide.rst | 20 ++-- docs/dockerhub.rst | 124 +++++++++++------------ docs/index.rst | 68 +++++-------- docs/license.md | 2 + docs/license.rst | 6 -- docs/sphinx/_toc.yml.in | 23 +++-- docs/tutorial_hello_world.rst | 149 ++++++++++------------------ docs/what-is-ck.rst | 41 ++++++++ docs/wrapper.rst | 12 ++- 11 files changed, 248 insertions(+), 246 deletions(-) create mode 100644 docs/license.md delete mode 100644 docs/license.rst create mode 100644 docs/what-is-ck.rst diff --git a/docs/API_Reference_Guide.rst b/docs/API_Reference_Guide.rst index f21d43c593..22222b0cf0 100644 --- a/docs/API_Reference_Guide.rst +++ b/docs/API_Reference_Guide.rst @@ -1,11 +1,13 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation -******************* -API Reference Guide -******************* +.. _api-reference: + +******************************************************************** +API reference guide +******************************************************************** -================= -Introduction -================= This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design principles that are used to write new classes that extend CK functionality. @@ -30,7 +32,7 @@ DeviceMem Kernels For Flashattention --------------------------- -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists +The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists the classes that are used in the CK GPU implementation of Flashattention. **Gridwise classes** diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index 41cb8f1915..b91984357a 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -1,9 +1,14 @@ -=================== -Contributor's Guide -=================== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation -This chapter explains how to get started contributing to the Composable Kernel project and what are -the contributing rules. +.. _contributing-to: + +******************************************************************** +Contributor's guide +******************************************************************** + +This chapter explains the rules for contributing to the Composable Kernel project, and how to contribute. Getting started =============== @@ -14,23 +19,21 @@ Getting started build the library. You can also find some of this information in the `README file `_ on the project's GitHub page. -#. **Additional reading:** We also recommend reading a `blog post +#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code `_ provides a deeper understanding of the CK library and showcases its performance capabilities. `_ - from the AMD Community portal. 
It offers a deeper understanding of the library's objectives and - showcases its performance capabilities. + from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities. #. **General information:** For broader information about AMD products, consider exploring the `AMD Developer Central portal `_. -How do I contribute +How to contribute =================== -We deeply value contributions from our users. You can make an impact by reporting issues or -proposing code enhancements through pull requests. +You can make an impact by reporting issues or proposing code enhancements through pull requests. Reporting issues ---------------- -We use `Github issues `_ +Use `Github issues `_ to track public bugs and enhancement requests. If you encounter an issue with the library, please check if the problem has already been @@ -59,7 +62,7 @@ issue. All reported issues must include: * How frequently does this issue happen? Does it reproduce every time? Or is it a sporadic issue? -Before sumbitting any issue, ensure you have addressed all relevant questions from the checklist. +Before submitting any issue, ensure you have addressed all relevant questions from the checklist. Creating Pull Requests ---------------------- @@ -68,7 +71,7 @@ You can submit `Pull Requests (PR) on GitHub `_. All contributors are required to develop their changes on a separate branch and then create a -pull requrest to merge their changes into the `develop` branch, which is the default +pull request to merge their changes into the `develop` branch, which is the default development branch in the Composable Kernel project. All external contributors must use their own forks of the project to develop their changes. @@ -99,4 +102,4 @@ When submitting a Pull Request you should: Following the above guidelines ensures a seamless review process and faster assistance from our end. -Thank you for your commitment to enhancing the Composable Kernel project! We look forward to collaborating with you. +Thank you for your commitment to enhancing the Composable Kernel project! diff --git a/docs/Supported_Primitives_Guide.rst b/docs/Supported_Primitives_Guide.rst index 3462283d90..e24acf5656 100644 --- a/docs/Supported_Primitives_Guide.rst +++ b/docs/Supported_Primitives_Guide.rst @@ -1,16 +1,20 @@ -========================== -Supported Primitives Guide -========================== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the -API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins -the algorithms implemented in CK. +.. _supported-primitives: + +******************************************************************** +Supported Primitives Guide +******************************************************************** + +This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. 
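For reference, the block-wise decomposition that the Softmax section below relies on can be written out explicitly. This is the standard "online" softmax identity, stated here under the same definitions of :math:`m`, :math:`f`, and :math:`z` used in that section; it is a reconstruction consistent with those definitions rather than a quotation from the original guide:

.. math::

   m(x) = \max\left( m(x^{(1)}), \ldots, m(x^{(T)}) \right)

.. math::

   \mathrm{softmax}(x) = \frac{\left[ e^{m(x^{(1)}) - m(x)} f(x^{(1)}) \ | \ \ldots \ | \ e^{m(x^{(T)}) - m(x)} f(x^{(T)}) \right]}{e^{m(x^{(1)}) - m(x)} z(x^{(1)}) + \ldots + e^{m(x^{(T)}) - m(x)} z(x^{(T)})}

In words: each tile is normalized with its local maximum, and the per-tile results are rescaled by :math:`e^{m(x^{(j)}) - m(x)}` so that the concatenation matches the softmax computed over the full vector in a single pass.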
------------ Softmax ------------ -For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the +For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` you can decompose the softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, .. math:: @@ -27,7 +31,7 @@ where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and :math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar. For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size -:math:`B_r \times B_c` we can compute the row-wise softmax as follows. +:math:`B_r \times B_c` you can compute the row-wise softmax as follows. For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate, diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst index cf420030ff..fb89bef72b 100644 --- a/docs/dockerhub.rst +++ b/docs/dockerhub.rst @@ -1,28 +1,50 @@ -=================== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _docker-hub: + +******************************************************************** CK Docker Hub +******************************************************************** + +Why do I need this? =================== -------------------------------------- -Why do I need this? -------------------------------------- +To make things simpler, and bring Composable Kernel and its dependencies together, +docker images can be found on `Docker Hub `_. Docker images provide a complete image of the OS, the Composable Kernel library, and its dependencies in a single downloadable file. -To make our lives easier and bring Composable Kernel dependencies together, we recommend using -docker images that can be found on `Docker Hub `_. +Refer to `Docker Overview `_ for more information on Docker images and containers. -------------------------------------- -So what is Composable Kernel? -------------------------------------- +Which image is right for me? +============================ -Composable Kernel (CK) library aims to provide a programming model for writing performance critical -kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, -through general purpose kernel languages, like HIP C++. +The image naming includes information related to the docker image. +For example ``ck_ub20.04_rocm6.0`` indicates the following: -To get the CK library:: +* ``ck`` - made for running Composable Kernel; +* ``ub20.04`` - based on Ubuntu 20.04; +* ``rocm6.0`` - ROCm platform version 6.0. - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git +Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. Use the ``docker pull`` command to download the file:: + + docker pull rocm/composable_kernel:ck_ub20.04_rocm6.0 -run a docker container:: +What is inside the image? +------------------------- + +The docker images have everything you need for running CK including: + +* `ROCm `_ +* `CMake `_ +* `Compiler `_ +* `Composable Kernel library `_ + +Running the docker container +============================ + +After downloading the docker image, you can start the container using one of a number of commands. 
Start with the ``docker run`` command as shown below:: docker run \ -it \ @@ -33,67 +55,47 @@ run a docker container:: rocm/composable_kernel:ck_ub20.04_rocm6.0 \ /bin/bash -and build the CK:: +After starting the bash shell, the docker container current folder is `~/workspace`. The library path is ``~/workspace/composable_kernel``. Navigate to the library to begin the tutorial as explained in :ref:`hello-world`: - mkdir build && cd build - # Need to specify target ID, example below is for gfx908 and gfx90a - cmake \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_CXX_FLAGS="-O3" \ - -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx908;gfx90a" \ - .. +.. note:: -and:: + If your current folder is different from `${HOME}`, adjust the line ``-v ${HOME}:/root/workspace`` in the ``docker run`` command to fit your folder structure. - make -j examples tests +Stop and restart the docker image +================================= -To run all the test cases including tests and examples run:: +After finishing the tutorial, or just when you have completed your work session, you can close the docker container, or stop the docker container to restart it at another time. Closing the docker container means that it is still in the active state, and can be resumed from where you left it. Stopping the container closes it, and returns the image to its initial state. - make test +Use the ``Ctrl-D`` option to exit the container, while leaving it active, so you can return to the container in its current state to resume the tutorial, or pickup your project where you left off. -We can also run specific examples or tests like:: +To restart the active container use the ``docker exec`` command to specify the container name and options as follows:: - ./bin/example_gemm_xdl_fp16 - ./bin/test_gemm_fp16 + docker exec -it bash -For more details visit `CK github repository `_, -`CK examples `_, -`even more CK examples `_. +Where: -------------------------------------- -And what is inside? -------------------------------------- +* `exec` is the docker command +* `-it` is the interactive option for `exec` +* `` specifies an active container on the system +* `bash` specifies the command to run in the interactive shell -The docker images have everything you need for running CK including: +.. note:: -* `ROCm `_ -* `CMake `_ -* `Compiler `_ + You can use the ``docker container ls`` command to list the active containers on the system. -------------------------------------- -Which image is right for me? -------------------------------------- +To start a container from the image, use the ``docker start`` command:: -Let's take a look at the image naming, for example ``ck_ub20.04_rocm6.0``. The image specs are: + docker start -* ``ck`` - made for running Composable Kernel; -* ``ub20.04`` - based on Ubuntu 20.04; -* ``rocm6.0`` - ROCm platform version 6.0. +Then use the docker exec command as shown above to start the bash shell. -So just pick the right image for your project dependencies and you're all set. +Use the ``docker stop`` command to stop the container and restore the image to its initial state:: -------------------------------------- -DIY starts here -------------------------------------- + docker stop + +Editing the docker image +======================= -If you need to customize a docker image or just can't stop tinkering, feel free to adjust the +If you want to customize the docker image, edit the `Dockerfile `_ -for your needs. 
- -------------------------------------- -License -------------------------------------- - -CK is released under the MIT `license `_. +from the GitHub repository to suit your needs. diff --git a/docs/index.rst b/docs/index.rst index 8c4aaa2b3d..5dbd2eb033 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,57 +1,39 @@ -============================ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _composable-kernel: + +******************************************************************** Composable Kernel User Guide -============================ +******************************************************************** ------------- -Introduction ------------- +The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. This document contains instructions for installing, using, and contributing to the Composable Kernel project. To learn more see :ref:`what-is-ck`. -This document contains instructions for installing, using, and contributing to Composable Kernel (CK). +The CK documentation is structured as follows: ------------ -Methodology ------------ +.. card:: Conceptual -Composable Kernel (CK) library aims to provide a programming model for writing performance critical -kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, -through general purpose kernel languages, like HIP C++. + * :ref:`what-is-ck` -CK utilizes two concepts to achieve performance portability and code maintainability: +.. card:: Installation -* A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call - "Tensor Coordinate Transformation". + * :ref:`docker-hub` -.. image:: data/ck_component.png - :alt: CK Components +.. card:: Tutorial --------------- -Code Structure --------------- + * :ref:`hello-world` -Current CK library are structured into 4 layers: +.. card:: API reference -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Wrapper for tensor transform operations" -* "Client API" layer + * :ref:`supported-primitives` + * :ref:`api-reference` + * :ref:`wrapper` -.. image:: data/ck_layer.png - :alt: CK Layers - -Documentation Roadmap -^^^^^^^^^^^^^^^^^^^^^ -The following is a list of CK documents in the suggested reading order: +.. card:: Contributing to CK -.. toctree:: - :maxdepth: 5 - :caption: Contents: - :numbered: + * :ref:`contributing-to` - tutorial_hello_world - dockerhub - wrapper - Supported_Primitives_Guide - API_Reference_Guide - Contributors_Guide +To contribute to the documentation refer to `Contributing to ROCm `_. + +You can find licensing information at the `Licensing `_ page. diff --git a/docs/license.md b/docs/license.md new file mode 100644 index 0000000000..43e471da0e --- /dev/null +++ b/docs/license.md @@ -0,0 +1,2 @@ +```{include} ../LICENSE.md +``` diff --git a/docs/license.rst b/docs/license.rst deleted file mode 100644 index ddb544496e..0000000000 --- a/docs/license.rst +++ /dev/null @@ -1,6 +0,0 @@ -======= -License -======= - -.. 
include:: ../LICENSE - :literal: diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index c37ba29cec..5780674624 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -1,10 +1,21 @@ -# Anywhere {branch} is used, the branch name will be substituted. -# These comments will also be removed. defaults: numbered: False - maxdepth: 6 root: index subtrees: -- caption: About - entries: - - file: license +- entries: + - file: what-is-ck.rst + title: What is Composable Kernel? + - file: dockerhub.rst + title: Docker Hub + - file: tutorial_hello_world.rst + title: Hello World Tutorial + - file: Supported_Primitives_Guide.rst + title: Supported Primitives + - file: API_Reference_Guide.rst + title: API Reference + - file: wrapper.rst + title: Wrapper + - file: Contributors_Guide.rst + title: Contributing to CK + - file: license.md + title: License diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst index 6b8154d462..d89331e579 100644 --- a/docs/tutorial_hello_world.rst +++ b/docs/tutorial_hello_world.rst @@ -1,52 +1,44 @@ -=============== -CK Hello world -=============== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation -------------------------------------- -Motivation -------------------------------------- +.. _hello-world: -This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who -would like to optimize their pipelines and squeeze every performance drop by adding Composable -Kernel (CK) library to their projects. We would like to make the CK library approachable so -the tutorial is not based on the latest release and doesn't have all the bleeding edge features, -but it will be reproducible now and forever. +******************************************************************** +Hello World Tutorial +******************************************************************** -During this tutorial we will have an introduction to the CK library, we will build it and run some -examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go -in depth and breadth and get familiar with other tools and ways to integrate CK into your project. +This tutorial is for engineers dealing with artificial intelligence and machine learning who +would like to optimize pipelines and improve performance using the Composable +Kernel (CK) library. This tutorial provides an introduction to the CK library. You will build the library and run some examples using a "Hello World" example. -------------------------------------- Description -------------------------------------- +=========== -Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and -efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast -and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create -new ones. The library has components required for majority of modern neural networks architectures -including matrix multiplication, convolution, contraction, reduction, attention modules, variety of -activation functions, fused operators and many more. +Modern AI technology solves more and more problems in a variety of fields, but crafting fast and +efficient workflows is still challenging. CK can make the AI workflow fast +and efficient. CK is a collection of optimized AI operator kernels with tools to create +new kernels. 
The library has components required for modern neural network architectures +including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, and fused operators. -So how do we (almost) reach the speed of light? CK acceleration abilities are based on: +CK library acceleration features are based on: -* Layered structure. -* Tile-based computation model. -* Tensor coordinate transformation. -* Hardware acceleration use. -* Support of low precision data types including fp16, bf16, int8 and int4. +* Layered structure +* Tile-based computation model +* Tensor coordinate transformation +* Hardware acceleration use +* Support of low precision data types including fp16, bf16, int8 and int4 -If you are excited and need more technical details and benchmarking results - read this awesome +If you need more technical details and benchmarking results read the following `blog post `_. -For more details visit our `github repository `_. +To download the library visit the `composable_kernel repository `_. -------------------------------------- Hardware targets -------------------------------------- +================ -CK library fully supports `gfx908` and `gfx90a` GPU architectures and only some operators are -supported for `gfx1030`. Let's check the hardware you have at hand and decide on the target -GPU architecture. +CK library fully supports `gfx908` and `gfx90a` GPU architectures, while only some operators are +supported for `gfx1030` devices. Check your hardware to determine the target GPU architecture. ========== ========= GPU Target AMD GPU @@ -59,47 +51,24 @@ gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6 There are also `cloud options `_ you can find if you don't have an AMD GPU at hand. -------------------------------------- Build the library -------------------------------------- +================= -First let's clone the library and rebase to the tested version:: +This tutorial is based on the use of docker images as explained in :ref:`docker-hub`. Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git - cd composable_kernel/ - git checkout tutorial_hello_world +.. note:: -To make our lives easier we prepared -`docker images `_ with all the necessary -dependencies. Pick the right image and create a container. In this tutorial we use -``rocm/composable_kernel:ck_ub20.04_rocm6.0`` image, it is based on Ubuntu 20.04 and -ROCm v6.0. + You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. -If your current folder is ``${HOME}``, start the docker container with:: - - docker run \ - -it \ - --privileged \ - --group-add sudo \ - -w /root/workspace \ - -v ${HOME}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm6.0 \ - /bin/bash - -If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace`` -to fit your folder structure. - -Inside the docker container current folder is ``~/workspace``, library path is -``~/workspace/composable_kernel``, navigate to the library:: +Both the docker container and GitHub repository include the Composable Kernel library. 
Navigate to the library:: cd composable_kernel/ -Create and go to the ``build`` directory:: +Create and change to a ``build`` directory:: mkdir build && cd build -In the previous section we talked about target GPU architecture. Once you decide which one is right -for you, run CMake using the right ``GPU_TARGETS`` flag:: +The previous section discussed supported GPU architecture. Once you decide which hardware targets are needed, run CMake using the ``GPU_TARGETS`` flag:: cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ @@ -109,26 +78,25 @@ for you, run CMake using the right ``GPU_TARGETS`` flag:: -D BUILD_DEV=OFF \ -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. -If everything went well the CMake run will end up with:: +If everything goes well the CMake command will return:: -- Configuring done -- Generating done -- Build files have been written to: "/root/workspace/composable_kernel/build" -Finally, we can build examples and tests:: +Finally, you can build examples and tests:: make -j examples tests -If everything is smooth, you'll see:: +When complete you should see:: Scanning dependencies of target tests [100%] Built target tests ---------------------------- Run examples and tests ---------------------------- +====================== -Examples are listed as test cases as well, so we can run all examples and tests with:: +Examples are listed as test cases as well, so you can run all examples and tests with:: ctest @@ -136,38 +104,32 @@ You can check the list of all tests by running:: ctest -N -We can also run them separately, here is a separate example execution:: +You can also run examples separately as shown in the following example execution:: ./bin/example_gemm_xdl_fp16 1 1 1 -The arguments ``1 1 1`` mean that we want to run this example in the mode: verify results with CPU, -initialize matrices with integers and benchmark the kernel execution. You can play around with -these parameters and see how output and execution results change. +The arguments ``1 1 1`` mean that you want to run this example in the mode: verify results with CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. -If everything goes well and you have a device based on `gfx908` or `gfx90a` architecture you should see -something like:: +If you have a device based on `gfx908` or `gfx90a` architecture, and if the example runs as expected, you should see something like:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} + b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} - Warm up 1 time - Start running 10 times... 
- Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 + Perf: 1.08153 ms, 119.136 TFlops, 89.1972 GB/s, DeviceGemm_Xdl_CShuffle LoopScheduler: Interwave, PipelineVersion: v1 -Meanwhile, running it on a `gfx1030` device should result in:: +However, running it on a `gfx1030` device should result in the following:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem -But don't panic, some of the operators are supported on `gfx1030` architecture, so you can run a +Don't worry, some operators are supported on `gfx1030` architecture, so you can run a separate example like:: ./bin/example_gemm_dl_fp16 1 1 1 -and it should result in something nice similar to:: +and it should return something like:: a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} @@ -182,12 +144,9 @@ and it should result in something nice similar to:: .. note:: - There was a new CMake flag ``DL_KERNELS`` added in the latest versions of CK. If you use one of - the newest versions of the library and do not see the above results when running - ``example_gemm_dl_fp16``, it might be necessary to add ``-D DL_KERNELS=ON`` to your CMake command - in order to build the operators supported on the `gfx1030` architecture. + A new CMake flag ``DL_KERNELS`` has been added to the latest versions of CK. If you do not see the above results when running ``example_gemm_dl_fp16``, you might need to add ``-D DL_KERNELS=ON`` to your CMake command to build the operators supported on the `gfx1030` architecture. -We can also run a separate test:: +You can also run a separate test:: ctest -R test_gemm_fp16 @@ -198,13 +157,9 @@ If everything goes well you should see something like:: 100% tests passed, 0 tests failed out of 1 ------------ Summary ------------ +======= -In this tutorial we took the first look at the Composable Kernel library, built it on your system -and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different -configs to find out the best one for your hardware and task. +In this tutorial you took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. In the next tutorial you will run kernels with different configurations to find out the best one for your hardware and task. -P.S.: Don't forget to switch off the cloud instance if you have launched one, you can find better -ways to spend your money for sure! +P.S.: If you are running on a cloud instance, don't forget to switch off the cloud instance. diff --git a/docs/what-is-ck.rst b/docs/what-is-ck.rst new file mode 100644 index 0000000000..f0b51c48f8 --- /dev/null +++ b/docs/what-is-ck.rst @@ -0,0 +1,41 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. 
_what-is-ck: + +******************************************************************** +What is the Composable Kernel library +******************************************************************** + + +Methodology +=========== + +The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. + +CK utilizes two concepts to achieve performance portability and code maintainability: + +* A tile-based programming model +* Algorithm complexity reduction for complex ML operators using an innovative technique called + "Tensor Coordinate Transformation". + +.. image:: data/ck_component.png + :alt: CK Components + + +Code Structure +============== + +The CK library is structured into 4 layers: + +* "Templated Tile Operators" layer +* "Templated Kernel and Invoker" layer +* "Instantiated Kernel and Invoker" layer +* "Client API" layer + +It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. + +.. image:: data/ck_layer.png + :alt: CK Layers + \ No newline at end of file diff --git a/docs/wrapper.rst b/docs/wrapper.rst index da3a79eda8..c050f17caf 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -1,6 +1,12 @@ -=============== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _wrapper: + +******************************************************************** Wrapper -=============== +******************************************************************** ------------------------------------- Description @@ -11,7 +17,7 @@ Description The wrapper is under development and its functionality is limited. -CK provides a lightweight wrapper for more complex operations implemented in +The CK library provides a lightweight wrapper for more complex operations implemented in the library. It allows indexing of nested layouts using a simple interface (avoiding complex descriptor transformations) and memory access (using Tensor). 
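To make the wrapper description above concrete, here is a minimal host-side sketch of the slicing interface exercised by the wrapper tests earlier in this series. The ``rank``/``size`` queries, ``slice()`` (including the negative-end form ``slice(0, -2)``), and nested ``ck::make_tuple`` indexing appear verbatim in those tests; the ``make_layout``/``make_tensor`` construction, the ``MemoryTypeEnum::Generic`` memory tag, and the header paths are assumptions about the wrapper API rather than code taken from this series::

    // A sketch only: construction calls and headers are assumed, not quoted.
    #include <vector>

    #include "ck/wrapper/layout.hpp" // assumed header path
    #include "ck/wrapper/tensor.hpp" // assumed header path

    int main()
    {
        // Eight contiguous host elements viewed through a nested
        // ((2, 2), 2) layout, matching the nested make_tuple indexing
        // used in the wrapper tests above.
        std::vector<int> data{0, 1, 2, 3, 4, 5, 6, 7};

        const auto layout = ck::wrapper::make_layout(
            ck::make_tuple(ck::make_tuple(2, 2), 2),  // shape (assumed example)
            ck::make_tuple(ck::make_tuple(1, 2), 4)); // strides (assumed example)

        auto tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(
            data.data(), layout);

        // Negative indexing as in the tests: keep entries [0, -2) of the
        // second sub-mode of mode 0, and the whole of mode 1.
        auto sub = tensor(ck::make_tuple(1, ck::wrapper::slice(0, -2)),
                          ck::wrapper::slice());

        // The sliced view is itself a tensor; no descriptor transformations
        // were written by hand to obtain it. In the tests above the analogous
        // slice reported rank 2, depth 2, size 2.
        const auto r = ck::wrapper::rank(sub);
        const auto s = ck::wrapper::size(sub);
        return (r > 0 && s > 0) ? 0 : 1;
    }

The point of the design is visible in the last step: selecting a sub-tensor of a nested layout is a single call on the tensor itself, rather than a chain of descriptor transforms as in the lower CK layers.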
From 22db1e0865b3c55091f02dd0f94831414fd58962 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 5 Jan 2024 13:54:40 -0800 Subject: [PATCH 46/75] fix dockerfile syntax for test compilers (#1120) --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index c058c85723..4b72855bac 100644 --- a/Dockerfile +++ b/Dockerfile @@ -117,7 +117,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if [ [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ] && [ "$compiler_commit" = "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -125,7 +125,7 @@ RUN if [ [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "am else echo "using the release compiler"; \ fi -RUN if [ [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ] && [ "$compiler_commit" != "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ From e699dbd8a3c0e75a980e1b9fcb9a359241e34309 Mon Sep 17 00:00:00 2001 From: raramakr <91213141+raramakr@users.noreply.github.com> Date: Tue, 9 Jan 2024 08:21:47 -0800 Subject: [PATCH 47/75] SWDEV-439954 - Use hard coded filename rather than using the macro __FILE__ for debug prints. (#1123) * SWDEV-439954 - Use hard coded filename rather than using the macro __FILE__ for debug prints. Hiptensor library is using the header files from CK. Hard coded ROCm path was getting embedded into the hiptensor library, since the header file was having the macro __FILE__. Replace the macro with filename. * fix syntax --------- Co-authored-by: illsilin --- include/ck/host_utility/hip_check_error.hpp | 28 ++++++++++--------- .../device/impl/device_contraction_utils.hpp | 10 ++++--- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index 3e44faecb6..c0894f1d70 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x) if(x != hipSuccess) { std::ostringstream ss; - ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ - << "in function: " << __func__; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". 
" + << "hip_check_error.hpp" + << ": " << __LINE__ << "in function: " << __func__; throw std::runtime_error(ss.str()); } } -#define HIP_CHECK_ERROR(retval_or_funcall) \ - do \ - { \ - hipError_t _tmpVal = retval_or_funcall; \ - if(_tmpVal != hipSuccess) \ - { \ - std::ostringstream ostr; \ - ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \ - << hipGetErrorString(_tmpVal); \ - throw std::runtime_error(ostr.str()); \ - } \ +#define HIP_CHECK_ERROR(retval_or_funcall) \ + do \ + { \ + hipError_t _tmpVal = retval_or_funcall; \ + if(_tmpVal != hipSuccess) \ + { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" \ + << "hip_check_error.hpp" \ + << "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ } while(0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp index 0e14b40942..838305f187 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -35,15 +35,17 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector Date: Tue, 9 Jan 2024 09:43:08 -0800 Subject: [PATCH 48/75] Add an option to change the number of warm-up cycles and iterations. (#1124) * allow setting the number of warmup cycles and iterations for profiler * fix the gemm_splitk and grouped_gemm examples --- include/ck/host_utility/kernel_launch.hpp | 13 ++++--- .../include/profiler/profile_gemm_impl.hpp | 8 ++-- .../profiler/profile_gemm_splitk_impl.hpp | 11 ++++-- .../profiler/profile_grouped_gemm_impl.hpp | 11 ++++-- profiler/src/profile_gemm.cpp | 16 +++++++- profiler/src/profile_gemm_splitk.cpp | 17 ++++++++- profiler/src/profile_grouped_gemm.cpp | 38 +++++++++++++++---- test/gemm_split_k/test_gemm_splitk_util.hpp | 19 ++++++++-- test/grouped_gemm/test_grouped_gemm_util.hpp | 19 ++++++++-- 9 files changed, 119 insertions(+), 33 deletions(-) diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index cd00598b52..1ed7686e7f 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, block_dim.y, block_dim.z); - printf("Warm up 1 time\n"); + printf("Warm up %d times\n", stream_config.cold_niters_); #endif // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) @@ -103,14 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, block_dim.y, block_dim.z); - printf("Warm up 1 time\n"); + printf("Warm up %d times\n", stream_config.cold_niters_); #endif // warm up preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } - const int nrepeat = 10; + const int nrepeat = stream_config.nrepeat_; #if DEBUG_LOG printf("Start running %d times...\n", nrepeat); #endif diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index 08416d0146..586a356ecc 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -42,7 +42,9 @@ int profile_gemm_impl(int do_verification, int K, int StrideA, int StrideB, - int StrideC) + int StrideC, + int n_warmup, + int n_iter) { bool pass = 
true; @@ -165,8 +167,8 @@ int profile_gemm_impl(int do_verification, std::string op_name = op_ptr->GetTypeString(); - float avg_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, 10, 50}); + float avg_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); std::size_t flop = std::size_t(2) * M * N * K; diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index eabb8e467d..6816d2c538 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -42,7 +42,9 @@ bool profile_gemm_splitk_impl(int do_verification, int StrideA, int StrideB, int StrideC, - int KBatch) + int KBatch, + int n_warmup, + int n_iter) { bool pass = true; @@ -177,7 +179,8 @@ bool profile_gemm_splitk_impl(int do_verification, // re-init C to zero before profiling next kernel c_device_buf.SetZero(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); if(do_verification) { @@ -200,8 +203,8 @@ bool profile_gemm_splitk_impl(int do_verification, std::string op_name = op_ptr->GetTypeString(); - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); std::size_t flop = std::size_t(2) * M * N * K; diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index fe7a397606..7f48ee0692 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -42,7 +42,9 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& StrideAs, const std::vector& StrideBs, const std::vector& StrideCs, - int kbatch = 1) + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) { bool pass = true; @@ -261,7 +263,8 @@ bool profile_grouped_gemm_impl(int do_verification, for(std::size_t i = 0; i < gemm_descs.size(); i++) c_device_buf[i]->SetZero(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); if(do_verification) { @@ -307,8 +310,8 @@ bool profile_grouped_gemm_impl(int do_verification, pass = pass && instance_pass; } - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); if(time_kernel) { diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index df243c96d6..0d6c5021f3 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -42,12 +42,15 @@ static void print_helper_msg() << "arg6: print tensor value (0: no; 1: yes)\n" << "arg7: time kernel (0: no, 1: yes)\n" << "arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n" + << "optional:\n" + << "arg14: number of warm-up cycles (default 1)\n" + << "arg15: number of iterations (default 10)\n" << std::endl; } int profile_gemm(int argc, char* argv[]) { - if(argc != 14) + if(argc != 14 && argc != 16) { print_helper_msg(); exit(1); @@ -68,6 +71,13 @@ int profile_gemm(int argc, char* argv[]) const int StrideB = std::stoi(argv[12]); const int StrideC = std::stoi(argv[13]); + 
int n_warmup = 1; + int n_iter = 10; + if(argc == 16) + { + n_warmup = std::stoi(argv[14]); + n_iter = std::stoi(argv[15]); + } using F32 = float; using F16 = ck::half_t; #ifdef CK_ENABLE_BF16 @@ -120,7 +130,9 @@ int profile_gemm(int argc, char* argv[]) K, (StrideA < 0) ? DefaultStrideA : StrideA, (StrideB < 0) ? DefaultStrideB : StrideB, - (StrideC < 0) ? DefaultStrideC : StrideC); + (StrideC < 0) ? DefaultStrideC : StrideC, + n_warmup, + n_iter); return pass ? 0 : 1; }; diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp index d8fa3d872a..2a0467bc81 100644 --- a/profiler/src/profile_gemm_splitk.cpp +++ b/profiler/src/profile_gemm_splitk.cpp @@ -33,7 +33,7 @@ enum struct GemmDataType int profile_gemm_splitk(int argc, char* argv[]) { - if(argc != 15) + if(argc != 15 && argc != 17) { printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, " @@ -48,6 +48,9 @@ int profile_gemm_splitk(int argc, char* argv[]) printf("arg7: time kernel (0=no, 1=yes)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg14: split k into mulitiple batch\n"); + printf("optional:\n"); + printf("arg15: number of warm-up cycles (default 1)\n"); + printf("arg16: number of iterations (default 10)\n"); exit(1); } @@ -67,6 +70,14 @@ int profile_gemm_splitk(int argc, char* argv[]) const int StrideC = std::stoi(argv[13]); const int KBatch = std::stoi(argv[14]); + int n_warmup = 1; + int n_iter = 10; + if(argc == 17) + { + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); + } + using F32 = float; using F16 = ck::half_t; #if defined CK_ENABLE_FP8 @@ -117,7 +128,9 @@ int profile_gemm_splitk(int argc, char* argv[]) (StrideA < 0) ? DefaultStrideA : StrideA, (StrideB < 0) ? DefaultStrideB : StrideB, (StrideC < 0) ? DefaultStrideC : StrideC, - KBatch); + KBatch, + n_warmup, + n_iter); return pass ? 0 : 1; }; diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 373b0c6729..25203d7b6c 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -69,7 +69,10 @@ int profile_grouped_gemm(int argc, char* argv[]) << "arg7: time kernel (0=n0, 1=yes)\n" << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " "64,64 64,64 128,128)\n" - << "arg15: kbatch value (default 4)\n" + << "arg15: kbatch value (default 1)\n" + << "optional:\n" + << "arg16: number of warm-up cycles (default 1)\n" + << "arg17: number of iterations (default 10)\n" << std::endl; exit(1); @@ -90,6 +93,15 @@ int profile_grouped_gemm(int argc, char* argv[]) const auto StrideBs = argToIntArray(argv[12]); const auto StrideCs = argToIntArray(argv[13]); const int kbatch = argc == 15 ? 
std::stoi(argv[14]) : 1; + + int n_warmup = 1; + int n_iter = 10; + if(argc == 17) + { + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); + } + #ifdef CK_ENABLE_FP16 if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -109,7 +121,9 @@ int profile_grouped_gemm(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - kbatch); + kbatch, + n_warmup, + n_iter); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { @@ -129,7 +143,9 @@ int profile_grouped_gemm(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - kbatch); + kbatch, + n_warmup, + n_iter); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) { @@ -149,7 +165,9 @@ int profile_grouped_gemm(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - kbatch); + kbatch, + n_warmup, + n_iter); } else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { @@ -169,7 +187,9 @@ int profile_grouped_gemm(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - kbatch); + kbatch, + n_warmup, + n_iter); } else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -189,7 +209,9 @@ int profile_grouped_gemm(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - kbatch); + kbatch, + n_warmup, + n_iter); } else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) { @@ -209,7 +231,9 @@ int profile_grouped_gemm(int argc, char* argv[]) StrideAs, StrideBs, StrideCs, - kbatch); + kbatch, + n_warmup, + n_iter); } else {
diff --git a/test/gemm_split_k/test_gemm_splitk_util.hpp b/test/gemm_split_k/test_gemm_splitk_util.hpp index 8243747a69..99d9d5e832 100644 --- a/test/gemm_split_k/test_gemm_splitk_util.hpp +++ b/test/gemm_split_k/test_gemm_splitk_util.hpp @@ -60,7 +60,9 @@ class TestGemmSplitK : public testing::Test const int StrideA, const int StrideB, const int StrideC, - int kbatch = 1) + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) { bool pass = ck::profiler::profile_gemm_splitk_impl( - verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, kbatch); + CLayout>(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); EXPECT_TRUE(pass); } };
diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 04b31dcc91..50f423ada3 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -63,7 +63,9 @@ class TestGroupedGemm : public testing::TestWithParam const std::vector& StrideAs, const std::vector& StrideBs, const std::vector& StrideCs, - int kbatch = 1) + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) { bool pass = ck::profiler::profile_grouped_gemm_impl float, ALayout, BLayout, - ELayout>( - verify_, init_method_, log_, bench_, Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, kbatch); + ELayout>(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); EXPECT_TRUE(pass); } };
From 0ce417269dbeedf9852e7c310b46900c882bca90 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 Jan 2024 11:27:03 -0700 Subject: [PATCH 49/75] Bump sphinxcontrib-bibtex from 2.6.1 to 2.6.2 in /docs/sphinx (#1129) Bumps [sphinxcontrib-bibtex](https://github.com/mcmtroffaes/sphinxcontrib-bibtex) from 2.6.1 to 2.6.2.
- [Changelog](https://github.com/mcmtroffaes/sphinxcontrib-bibtex/blob/develop/CHANGELOG.rst) - [Commits](https://github.com/mcmtroffaes/sphinxcontrib-bibtex/compare/2.6.1...2.6.2) --- updated-dependencies: - dependency-name: sphinxcontrib-bibtex dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 6bcd2c43de..5b9f8f197f 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ rocm-docs-core==0.30.3 -sphinxcontrib-bibtex==2.6.1 +sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index e705e35e13..7bf6844ff6 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -149,7 +149,7 @@ sphinx-notfound-page==0.8.3 # via rocm-docs-core sphinxcontrib-applehelp==1.0.4 # via sphinx -sphinxcontrib-bibtex==2.6.1 +sphinxcontrib-bibtex==2.6.2 # via -r requirements.in sphinxcontrib-devhelp==1.0.2 # via sphinx From 636a31015a801e09c7f67f3a0544e8db34b58269 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Jan 2024 09:09:13 -0700 Subject: [PATCH 50/75] Bump rocm-docs-core from 0.30.3 to 0.31.0 in /docs/sphinx (#1131) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.30.3 to 0.31.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.30.3...v0.31.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 5b9f8f197f..23a4c4bb91 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.30.3 +rocm-docs-core==0.31.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 7bf6844ff6..1e5e688dac 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.30.3 +rocm-docs-core==0.31.0 # via -r requirements.in six==1.16.0 # via From e6d099c8309576c73ed3129d4e87ebc126c9a03e Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 15 Jan 2024 09:11:45 -0800 Subject: [PATCH 51/75] Add cppcheck to CK CI. 
(#1125) * add cppcheck to the CK CI * fix the path to CK source for cppcheck * fix the path to CK source for cppcheck one more time * fix the path to CK source for cppcheck third time * change the path to ck_cppcheck.log * install latest cppcheck from source * fix bug in ck.hpp and use 20 threads for cppcheck * create a switch to turn cppckeck on and off in CI --- Dockerfile | 7 ++++++- Jenkinsfile | 34 +++++++++++++++++++++++++++++++++- include/ck/ck.hpp | 2 +- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4b72855bac..a805285a77 100644 --- a/Dockerfile +++ b/Dockerfile @@ -74,7 +74,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* -#Install latest version of cmake +#Install ninja build tracing tools RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip RUN gunzip /usr/local/bin/ninja.gz RUN chmod a+x /usr/local/bin/ninja @@ -82,6 +82,11 @@ RUN git clone https://github.com/nico/ninjatracing.git # Update the cmake to the latest version RUN pip install --upgrade cmake==3.27.5 +#Install latest cppcheck +RUN git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . +WORKDIR / + # Setup ubsan environment to printstacktrace RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer ENV UBSAN_OPTIONS=print_stacktrace=1 diff --git a/Jenkinsfile b/Jenkinsfile index 268cc7606f..e333a35ecd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -304,7 +304,7 @@ def buildHipClangJob(Map conf=[:]){ gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 5, unit: 'HOURS') + timeout(time: 20, unit: 'HOURS') { cmake_build(conf) } @@ -709,6 +709,10 @@ pipeline { name: "USE_SCCACHE", defaultValue: true, description: "Use the sccache for building CK (default: ON)") + booleanParam( + name: "RUN_CPPCHECK", + defaultValue: false, + description: "Run the cppcheck static analysis (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -735,7 +739,35 @@ pipeline { } stage("Static checks") { parallel{ + stage('Clang Format and Cppcheck') { + when { + beforeAgent true + expression { params.RUN_CPPCHECK.toBoolean() } + } + agent{ label rocmnode("nogpu") } + environment{ + execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ + -o -not -path \'*.git*\' -iname \'*.hpp\' \ + -o -not -path \'*.git*\' -iname \'*.cpp\' \ + -o -iname \'*.h.in\' \ + -o -iname \'*.hpp.in\' \ + -o -iname \'*.cpp.in\' \ + -o -iname \'*.cl\' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include --file-filter=*.cpp --enable=all --output-file=ck_cppcheck.log" + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + archiveArtifacts "build/ck_cppcheck.log" + cleanWs() + } + } stage('Clang Format') { + when { + beforeAgent true + expression { !params.RUN_CPPCHECK.toBoolean() } + } agent{ label rocmnode("nogpu") } environment{ execute_cmd = "find .. 
-not -path \'*.git*\' -iname \'*.h\' \ diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index a94057be4a..88efb0277b 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -218,7 +218,7 @@ // denorm test fix, required to work around dissue #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 -#elif +#else // enable only on MI200 #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // CK_WORKAROUND_DENORM_FIX From c1b5b58192aebb6c3df782a27f36729b85fb9bbb Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 16 Jan 2024 07:55:18 -0800 Subject: [PATCH 52/75] add code owners (#1132) --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 30f0dedd8d..11648bfd27 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,4 @@ +* @zjing14 @asroy @junliume @illsilin @carlushuang # Documentation files docs/* @saadrahim @LisaDelaney *.md @saadrahim @LisaDelaney From 402a930a4a72d913794a57028fe37f8cf274c569 Mon Sep 17 00:00:00 2001 From: randyh62 <42045079+randyh62@users.noreply.github.com> Date: Tue, 16 Jan 2024 09:00:37 -0800 Subject: [PATCH 53/75] Randyh docfix (#1130) * Update LICENSE update to 2024 * Update index.rst change license.md to license.html * fix syntax --------- Co-authored-by: illsilin --- LICENSE | 2 +- docs/index.rst | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/LICENSE b/LICENSE index e03fddaf78..581b5efde5 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/docs/index.rst b/docs/index.rst index 5dbd2eb033..8ae4ce3a22 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -34,6 +34,6 @@ The CK documentation is structured as follows: * :ref:`contributing-to` -To contribute to the documentation refer to `Contributing to ROCm `_. +To contribute to the documentation refer to `Contributing to ROCm `_. -You can find licensing information at the `Licensing `_ page. +You can find licensing information on the `Licensing `_ page. 
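The ck.hpp hunk in the cppcheck patch above replaces a bare ``#elif`` with ``#else``: ``#elif`` requires a condition, so the original line was ill-formed. A minimal standalone sketch of the intended default-or-gate pattern, using a hypothetical macro name rather than the library's::

    // If the build system did not define the macro, default it to off.
    // If it did define the macro, honor it only on the intended target
    // (gfx90a, i.e. MI200, in the hunk above).
    #ifndef MY_WORKAROUND
    #define MY_WORKAROUND 0
    #else
    #if MY_WORKAROUND && !defined(__gfx90a__)
    #undef MY_WORKAROUND
    #define MY_WORKAROUND 0
    #endif
    #endif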
From 38882d8ab595b76fc1d328c1079884471bc63963 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:20:40 -0800 Subject: [PATCH 54/75] add Adam to code owners (#1136) --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 11648bfd27..e4d0d47a2e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -* @zjing14 @asroy @junliume @illsilin @carlushuang +* @zjing14 @asroy @junliume @illsilin @carlushuang @aosewski # Documentation files docs/* @saadrahim @LisaDelaney *.md @saadrahim @LisaDelaney From 7e4eb4b800b7bec8adb9a1a766f7aba1557e8aa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 19 Jan 2024 11:29:00 +0100 Subject: [PATCH 55/75] Add optimized copy to ck wrapper (#1126) * Add optimized copy to ck wrapper * Example optimizations * Fixes * Move img2col test to client example * Refactor example * Fix docs * Fixes * Fix * Fixes * Fixes * Fixes * Fixes * Fixes --------- Co-authored-by: zjing14 --- CHANGELOG.md | 17 +- .../25_tensor_transforms/tensor_transform.cpp | 150 ----- .../CMakeLists.txt | 4 +- .../tensor_transform_using_wrapper.cpp | 2 +- client_example/25_wrapper/wrapper_img2col.cpp | 180 ++++++ docs/wrapper.rst | 8 +- .../ck/utility/is_known_at_compile_time.hpp | 8 +- include/ck/wrapper/layout.hpp | 192 +++++-- include/ck/wrapper/operations/copy.hpp | 140 ++++- include/ck/wrapper/tensor.hpp | 511 ++++++++++-------- include/ck/wrapper/utils/layout_utils.hpp | 81 ++- include/ck/wrapper/utils/tensor_partition.hpp | 376 +++++-------- include/ck/wrapper/utils/tensor_utils.hpp | 111 ++-- .../cpu/reference_image_to_column.hpp | 3 +- test/wrapper/test_copy.cpp | 79 +-- test/wrapper/test_partition.cpp | 89 ++- test/wrapper/test_tensor.cpp | 23 +- 17 files changed, 1109 insertions(+), 865 deletions(-) delete mode 100644 client_example/25_tensor_transforms/tensor_transform.cpp rename client_example/{25_tensor_transforms => 25_wrapper}/CMakeLists.txt (55%) rename client_example/{25_tensor_transforms => 25_wrapper}/tensor_transform_using_wrapper.cpp (98%) create mode 100644 client_example/25_wrapper/wrapper_img2col.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index abca69142e..12cc4363de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,21 @@ Full documentation for Composable Kernel is not yet available. -## (Unreleased) CK for ROCm 6.0.0 +## (Unreleased) CK + +### Fixes +None + +### Optimizations +None + +### Additions +- Introduce wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) + +### Changes +None + +## CK for ROCm 6.0.0 ### Fixes - Fixed a hazard associated with inline v_dot (#808) @@ -19,7 +33,6 @@ None - Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) - Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) - Support for Batched Gemm DL (#732) -- Introduce wrapper sublibrary (limited functionality). (#1071, #1098, #1108) ### Changes - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) diff --git a/client_example/25_tensor_transforms/tensor_transform.cpp b/client_example/25_tensor_transforms/tensor_transform.cpp deleted file mode 100644 index 41ceec1cb5..0000000000 --- a/client_example/25_tensor_transforms/tensor_transform.cpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
- -#include - -#include "ck/ck.hpp" - -#include "ck/utility/number.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/sequence.hpp" - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_description/multi_index_transform_helper.hpp" - -static constexpr auto I0 = ck::Number<0>{}; -static constexpr auto I1 = ck::Number<1>{}; -static constexpr auto I2 = ck::Number<2>{}; - -using DataType = int; - -template -void Print1d(const Desc& desc) -{ - std::cout << "Print1d" << std::endl; - for(ck::index_t w = 0; w < desc.GetLength(I0); w++) - { - std::cout << desc.CalculateOffset(ck::make_tuple(w)) << " "; - } - std::cout << std::endl; -} - -template -void Print2d(const Desc& desc) -{ - std::cout << "Print2d" << std::endl; - for(ck::index_t h = 0; h < desc.GetLength(I0); h++) - { - for(ck::index_t w = 0; w < desc.GetLength(I1); w++) - { - std::cout << desc.CalculateOffset(ck::make_tuple(h, w)) << " "; - } - std::cout << std::endl; - } -} - -template -void Print3dCustom(const Desc& desc) -{ - std::cout << "Print3dCustom" << std::endl; - for(ck::index_t d = 0; d < desc.GetLength(I0); d++) - { - for(ck::index_t h = 0; h < desc.GetLength(I1); h++) - { - for(ck::index_t w = 0; w < desc.GetLength(I2); w++) - { - std::cout << desc.CalculateOffset(ck::make_tuple(d, h, w)) << " "; - } - std::cout << std::endl; - } - std::cout << std::endl; - } -} - -int main() -{ - // Tensor descriptor traverse in row-major (need to reverse dims) - std::cout << "Note: Tensor descriptor traverse in row-major" << std::endl; - // Basic descriptor 0, 1, 2, ... 30, 31 - // (dims:4,8 strides:1,4) - const auto desc_4x8_s1x4 = - ck::make_naive_tensor_descriptor(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}), - ck::make_tuple(ck::Number<1>{}, ck::Number<4>{})); - std::cout << "dims:4,8 strides:1,4" << std::endl; - Print2d(desc_4x8_s1x4); - - using Cord1x1Type = ck::Tuple, ck::Number<1>>; - constexpr ck::index_t offset_1x1 = desc_4x8_s1x4.CalculateOffset(Cord1x1Type{}); - std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl; - - // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) - // dims:4,(2,4) strides:2,(1,8) - const auto desc_4x2x4_s2x1x8 = - ck::make_naive_tensor_descriptor(ck::make_tuple(4, 2, 4), ck::make_tuple(2, 1, 8)); - // Transform to 2d (column-major, need to to reverse dims) - const auto desc_4x2x4_s2x1x8_merged = ck::transform_tensor_descriptor( - desc_4x2x4_s2x1x8, - ck::make_tuple(ck::make_pass_through_transform(4), - ck::make_merge_transform(ck::make_tuple(4, 2))), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<2, 1>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); - - std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; - Print2d(desc_4x2x4_s2x1x8_merged); - - // Basic descriptor 0, 1, 8, 9, 16, 17, ... 
30, 31 (compile-time descriptor) - // dims:(2,2),(2,4) strides:((1,4),(2,8) - const auto desc_2x2x2x4_s1x4x2x8 = - ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8)); - // Transform to 2d - const auto desc_2x2x2x4_s1x4x2x8_double_merged_2d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), - ck::make_merge_transform(ck::make_tuple(4, 2))), - ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<3, 2>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); - // Transform to 3d - const auto desc_2x2x2x4_s1x4x2x8_double_merged_3d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8, - ck::make_tuple(ck::make_pass_through_transform(2), - ck::make_pass_through_transform(2), - ck::make_merge_transform(ck::make_tuple(4, 2))), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<3, 2>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); - - std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; - Print2d(desc_2x2x2x4_s1x4x2x8_double_merged_2d); - Print3dCustom(desc_2x2x2x4_s1x4x2x8_double_merged_3d); - - // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) - // dims:((2,2),2),4 strides:((1,4),2),8 - // Transform to 2d - const auto desc_2x2x2x4_s1x4x2x8_nested = - ck::make_naive_tensor_descriptor(ck::make_tuple(2, 2, 2, 4), ck::make_tuple(1, 4, 2, 8)); - const auto desc_2x2x2x4_s1x4x2x8_nested_merged_3d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8_nested, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 2)), - ck::make_pass_through_transform(2), - ck::make_pass_through_transform(4)), - ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}, ck::Sequence<3>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}, ck::Sequence<2>{})); - const auto desc_2x2x2x4_s1x4x2x8_nested_merged_1d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8_nested, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(4, 2, 2, 2))), - ck::make_tuple(ck::Sequence<3, 2, 1, 0>{}), - ck::make_tuple(ck::Sequence<0>{})); - const auto desc_2x2x2x4_s1x4x2x8_nested_merged_2d = ck::transform_tensor_descriptor( - desc_2x2x2x4_s1x4x2x8_nested_merged_3d, - ck::make_tuple(ck::make_merge_transform(ck::make_tuple(2, 4)), - ck::make_pass_through_transform(4)), - ck::make_tuple(ck::Sequence<1, 0>{}, ck::Sequence<2>{}), - ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); - - std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; - Print1d(desc_2x2x2x4_s1x4x2x8_nested_merged_1d); - Print2d(desc_2x2x2x4_s1x4x2x8_nested_merged_2d); - Print3dCustom(desc_2x2x2x4_s1x4x2x8_nested_merged_3d); - - return 0; -} diff --git a/client_example/25_tensor_transforms/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt similarity index 55% rename from client_example/25_tensor_transforms/CMakeLists.txt rename to client_example/25_wrapper/CMakeLists.txt index d1543fb0ef..eb3be0e6c8 100644 --- a/client_example/25_tensor_transforms/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -1,4 +1,4 @@ -add_executable(client_tensor_transform tensor_transform.cpp) -target_link_libraries(client_tensor_transform PRIVATE composable_kernel::device_other_operations) add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) +add_executable(client_wrapper_img2col 
wrapper_img2col.cpp) +target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp similarity index 98% rename from client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp rename to client_example/25_wrapper/tensor_transform_using_wrapper.cpp index de9fcde0b4..4b25d85e2d 100644 --- a/client_example/25_tensor_transforms/tensor_transform_using_wrapper.cpp +++ b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp new file mode 100644 index 0000000000..35074be4c1 --- /dev/null +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" + +static constexpr ck::index_t NumDimSpatial = 3; +using DataType = float; +using InputLayout = ck::tensor_layout::convolution::NDHWGC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +// Test copy from Global to Global through LDS and VGPR +template +__global__ void DeviceImageToColumnPad0(InputTensor input_tensor, + OutputTensor output_tensor, + const BlockShape tile_shape, + const ThreadLayoutShape thread_layout) +{ + const ck::index_t block_idx = static_cast(blockIdx.x); + + // Get local tiles for global memory + auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); + auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); + + // Get partition per thread + const auto input_local_partition = + ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x); + auto output_local_partition = + ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x); + + // Perform copy + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + constexpr ck::index_t scalar_per_vector = 4; + ck::wrapper::copy(input_local_partition, + output_local_partition); +} + +void PerformImageToColumnPad0(const ck::index_t G, + const ck::index_t N, + const ck::index_t Di, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t Do, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t C, + const ck::index_t Z, + const ck::index_t Y, + const ck::index_t X, + std::array filter_strides, + std::array filter_dilations) +{ + const ck::index_t ZYXC = Z * Y * X * C; + const ck::index_t GC = G * C; + + // shape: (G, (Wo, Ho, Do, N)), (C, X, Y, Z)) + const auto shape = ck::make_tuple(ck::make_tuple(G, ck::make_tuple(Wo, Ho, Do, N)), + 
ck::make_tuple(C, X, Y, Z)); + const auto in_strides = + ck::make_tuple(ck::make_tuple(C, + ck::make_tuple(filter_strides[2] * GC, + filter_strides[1] * Wi * GC, + filter_strides[0] * Hi * Wi * GC, + Di * Hi * Wi * GC)), + ck::make_tuple(1, + filter_dilations[2] * GC, + filter_dilations[1] * Wi * GC, + filter_dilations[0] * Hi * Wi * GC)); + const auto in_layout = ck::wrapper::make_layout(shape, in_strides); + + const auto out_strides = ck::make_tuple( + ck::make_tuple( + ZYXC, + ck::make_tuple(ZYXC * G, Wo * ZYXC * G, Ho * Wo * ZYXC * G, Do * Ho * Wo * ZYXC * G)), + ck::make_tuple(1, C, X * C, Y * X * C)); + const auto out_layout = ck::wrapper::make_layout(shape, out_strides); + + const ck::index_t input_size = N * Di * Hi * Wi * GC; + // Global memory buffers + SimpleDeviceMem in_buf(input_size * sizeof(DataType)); + SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType)); + + // User can choose appropriate number of threads and sizes per block + const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); + // This example doesn't support padding, user should select tile sizes + // which divides the shape completely + const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{}); + + // Create buffers for global memory + auto input_tensor_global = ck::wrapper::make_tensor( + static_cast(in_buf.GetDeviceBuffer()), in_layout); + auto output_tensor_global = ck::wrapper::make_tensor( + static_cast(out_buf.GetDeviceBuffer()), out_layout); + + const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout), + ck::wrapper::size<0>(tile_shape)) * + ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout), + ck::wrapper::size<1>(tile_shape)); + + const auto kernel = DeviceImageToColumnPad0; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size), + dim3(ck::wrapper::size(thread_layout)), + 0, + input_tensor_global, + output_tensor_global, + tile_shape, + thread_layout); + + std::size_t num_btype = G * N * Do * Ho * Wo * ZYXC * 2 * sizeof(DataType); + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << std::endl; +} + +int main(int argc, char* argv[]) +{ + constexpr ck::index_t G = 4; // number of groups + constexpr ck::index_t N = 32; // batch + constexpr ck::index_t C = 64; // input channel (per group) + constexpr ck::index_t Z = 3; // filter D + constexpr ck::index_t Y = 3; // filter H + constexpr ck::index_t X = 3; // filter W + constexpr ck::index_t Di = 9; // input D + constexpr ck::index_t Hi = 9; // input H + constexpr ck::index_t Wi = 7; // input W + constexpr ck::index_t Do = 7; // output D + constexpr ck::index_t Ho = 7; // output H + constexpr ck::index_t Wo = 5; // output W + PerformImageToColumnPad0(G, + N, + Di, + Hi, + Wi, + Do, + Ho, + Wo, + C, + Z, + Y, + X, + {1, 1, 1} /*filter_strides*/, + {1, 1, 1} /*filter_dilations*/); + return 0; +} diff --git a/docs/wrapper.rst b/docs/wrapper.rst index c050f17caf..79b6c75580 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -18,8 +18,7 @@ Description The CK library provides a lightweight wrapper for more complex operations implemented in -the library. It allows indexing of nested layouts using a simple interface -(avoiding complex descriptor transformations) and memory access (using Tensor). +the library. 
Example: @@ -54,6 +53,11 @@ Output:: 1 5 9 13 17 21 25 29 2 6 10 14 18 22 26 30 + +Advanced examples: + +* `Image to column `_ + ------------------------------------- Layout ------------------------------------- diff --git a/include/ck/utility/is_known_at_compile_time.hpp b/include/ck/utility/is_known_at_compile_time.hpp index 2cafc3e6f2..0916e4604e 100644 --- a/include/ck/utility/is_known_at_compile_time.hpp +++ b/include/ck/utility/is_known_at_compile_time.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,6 +19,12 @@ struct is_known_at_compile_time static constexpr bool value = false; }; +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + template <> struct is_known_at_compile_time { diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 1643eb7383..39b5c79c67 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,22 +14,28 @@ namespace wrapper { * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t * (dynamic layout). It is possible to pass nested shapes * (e.g. ((4, 2), 2)), nested dimensions are merged. - * \tparam UnnestedDescriptorType Tensor descriptor for unnested shape dims. + * \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims. */ -template +template struct Layout { private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - // Generate default idxs tuple (idx with all merged nested shapes) + /** + * \brief Generate default indices tuple (idx with all merged nested shapes) + * + * \param shape Shape to align. + * \return Multi idx tuple with zeros. + */ template - __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple&) + __host__ __device__ constexpr static auto + GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple& shape) { return generate_tuple( [&](auto) { - if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime()) + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) { // runtime layout return index_t(0); @@ -43,11 +49,18 @@ struct Layout Number::Size()>{}); } - // Generate LowerDims in Compile-time for MergeTrasform using passed Type - // If element of Tuple is also tuple, then merge (generate sequence for merge) - // If tuple is element, then pass through (sequence with one element) + /** + * \brief Generate lower dims in compile-time for the Merge transform using + * provided type. If element of nested Tuple is also a tuple, then + * merge (generate sequence for merge). If tuple is element, then pass + * through (sequence with one element). + * + * \param shape Shape to align. + * \return LowerDims for MergeTrasform. 
+ */ template - __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple&) + __host__ __device__ constexpr static auto + GenerateLowerDim([[maybe_unused]] const Tuple& shape) { if constexpr(Idx::value == 0) { @@ -87,11 +100,17 @@ struct Layout } } - // Iterate over nested tuples in shape - // Unroll nested tuples to align Tuple to Tuple - // Example idx: (1, 1), 1, 1 - // Example shape: (2, (2, 2)), 2, (2, 2) - // Unrolled shape: 2, (2, 2), 2, (2, 2) + /** + * \brief Iterate over the nested tuples in the shape. + * Unroll nested tuples to align Tuple to Tuple + * Example idx: (1, 1), 1, 1 + * Example shape: (2, (2, 2)), 2, (2, 2) + * Unrolled shape: 2, (2, 2), 2, (2, 2) + * + * \param shape Layout shape. + * \param idx Idx to align. + * \return Algined shape. + */ template __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple& shape, const Tuple& idx) @@ -126,6 +145,13 @@ struct Layout } } + /** + * \brief Merge descriptor to 1D. + * + * \param shape Layout shape. + * \param desc Descriptor to merge. + * \return 1D descriptor. + */ template __host__ __device__ constexpr static auto MakeMerge1d(const Tuple& shape, const DescriptorToMerge& desc) @@ -137,18 +163,41 @@ struct Layout const auto lower_dims = make_tuple(MergeElemsSequence::Reverse()); const auto upper_dims = make_tuple(Sequence<0>{}); // Merge to 1d - return transform_tensor_descriptor( - desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) + { + return transform_tensor_descriptor( + desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims); + } + else + { + // If the descriptor is known at the compilation time, + // use `make_merge_transform_v1_carry_check` because it doesn't use + // memcpy. + return transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform_v1_carry_check(merge_elems)), + lower_dims, + upper_dims); + } } - // Merge nested shape dims when corresponding index is also nested. - // Input desc shape: 2, 2, 2, 2, 2, 2 - // Example idx: 1, 1, 1, 1 - // Example shape: 2, (2, 2), 2, (2, 2) - // Merged shape: 2, 4, 2, 4 + /** + * \brief Merge nested shape dims when corresponding index is also merged. + * Input desc shape: 2, 2, 2, 2, 2, 2 + * Example idx: 1, 1, 1, (1, 1) + * Example shape: 2, (2, 2), 2, (2, 2) + * Merged shape: 2, 4, 2, 2, 2 + * + * \param shape Layout shape. + * \param idxs Indexes to align descriptor. + * \param desc Descriptor to merge. + * \return Aligned descriptor to idx. + */ template - __host__ __device__ constexpr static auto CreateMergedDescriptor( - const Tuple& shape, const Tuple&, DescriptorToMerge& desc) + __host__ __device__ constexpr static auto + CreateMergedDescriptor(const Tuple& shape, + [[maybe_unused]] const Tuple& idxs, + DescriptorToMerge& desc) { const auto transforms = generate_tuple( [&](auto i) { @@ -160,7 +209,17 @@ struct Layout // If shape element is tuple and idx element is Number, then merge // Unroll and reverse tuple to traverse column-major const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i))); - return make_merge_transform(merge_elems); + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) + { + return make_merge_transform(merge_elems); + } + else + { + // If the descriptor is known at the compilation time, + // use `make_merge_transform_v1_carry_check` because + // it doesn't use memcpy. 
+ return make_merge_transform_v1_carry_check(merge_elems); + } } else { @@ -185,14 +244,23 @@ struct Layout } using Descriptor1dType = - remove_cvref_t; + remove_cvref_t; using DefaultIdxsTupleType = remove_cvref_t; + public: + /** + * \brief Transform descriptor to align to passed indexes. + * + * \param shape Layout shape. + * \param idxs Indexes to align descriptor. + * \param naive_descriptor Descriptor to merge. + * \return Aligned descriptor to idx. + */ template __host__ __device__ constexpr static auto TransformDesc(const Tuple& shape, - const Tuple& idx, - const UnnestedDescriptorType& naive_descriptor) + const Tuple& idxs, + const UnrolledDescriptorType& naive_descriptor) { if constexpr(Tuple::Size() == I1) { @@ -208,19 +276,18 @@ struct Layout static_assert(Tuple::Size() == Tuple::Size(), "Idx rank and Shape rank must be the same (except 1d)."); // Unroll while IdxDims is nested - const auto aligned_shape = AlignShapeToIdx(shape, idx); + const auto aligned_shape = AlignShapeToIdx(shape, idxs); // Transform correct form of shape - return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor); + return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor); } } using MergedNestsDescriptorType = remove_cvref_t; + Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>; - public: __host__ __device__ constexpr auto GetElementSpaceSize() const { - return unnested_descriptor_.GetElementSpaceSize(); + return unrolled_descriptor_.GetElementSpaceSize(); } __host__ __device__ Layout() = delete; @@ -232,16 +299,15 @@ struct Layout * \param unnested_descriptor Descriptor */ __host__ __device__ constexpr Layout(const Shape& shape, - const UnnestedDescriptorType& unnested_descriptor) - : shape_(shape) + const UnrolledDescriptorType& unnested_descriptor) + : unrolled_descriptor_(unnested_descriptor), shape_(shape) { // Construct if runtime mode - if constexpr(!UnnestedDescriptorType::IsKnownAtCompileTime()) + if constexpr(!remove_cvref_t::IsKnownAtCompileTime()) { - unnested_descriptor_ = unnested_descriptor; - descriptor_1d_ = MakeMerge1d(shape_, unnested_descriptor_); + descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_); merged_nests_descriptor_ = - TransformDesc(shape_, DefaultIdxsTupleType{}, unnested_descriptor_); + TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_); } } @@ -254,9 +320,9 @@ struct Layout template __host__ __device__ constexpr index_t operator()() const { - static_assert(UnnestedDescriptorType::IsKnownAtCompileTime(), + static_assert(remove_cvref_t::IsKnownAtCompileTime(), "Compiletime operator used on runtime layout."); - using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnnestedDescriptorType{})); + using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{})); using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{})); return TransformedDesc{}.CalculateOffset(UnrolledIdx{}); } @@ -283,7 +349,7 @@ struct Layout else { // Custom index, need to transform descriptor - const auto transformed_desc = TransformDesc(shape_, Idx, unnested_descriptor_); + const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_); return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx)); } } @@ -350,29 +416,55 @@ struct Layout } /** - * \brief Get default descriptor (with the same size as Shape) + * \brief Get descriptor with all nested dimensions merged. + * Example, shape: ((2, 2), 2) + * Descriptor lengths: (4, 2) * - * \return Default descriptor. 
+ * \note The size of merged descriptor is the same as Layout's shape. + * + * \return Merged nests descriptor. */ - __host__ __device__ constexpr const MergedNestsDescriptorType& GetDefaultDescriptor() const + __host__ __device__ constexpr const MergedNestsDescriptorType& + GetMergedNestingDescriptor() const { return merged_nests_descriptor_; } /** - * \brief Get unnested descriptor (with unrolled dims) + * \brief Get descriptor with all dimensions are merged (1D). + * Example, shape: ((2, 2), 2) + * Descriptor lengths: (8) * - * \return Flatten descriptor. + * \return 1D descriptor. */ - __host__ __device__ constexpr const UnnestedDescriptorType& GetUnnestedDescriptor() const + __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const { - return unnested_descriptor_; + return descriptor_1d_; + } + + /** + * \brief Get unnested descriptor (with unrolled dims) + * Example, shape: ((2, 2), 2) + * Descriptor lengths: (2, 2, 2) + * + * \return Flattened descriptor. + */ + __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const + { + return unrolled_descriptor_; } private: - UnnestedDescriptorType unnested_descriptor_; + // All dimensions are unrolled + UnrolledDescriptorType unrolled_descriptor_; + // 1D descriptor Descriptor1dType descriptor_1d_; + // All nesting are merged MergedNestsDescriptorType merged_nests_descriptor_; + // Example, shape: ((2, 2), 2) + // UnrolledDescriptorType lengths: (2, 2, 2) + // Descriptor1dType lengths: (8) + // MergedNestsDescriptorType lengths: (4, 2) const Shape shape_; }; diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index aec80f9ca7..7b00fe5500 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -1,16 +1,21 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "../utils/tensor_utils.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + namespace ck { namespace wrapper { /** - * \brief Perform generic copy between two tensors. Tensors must have the - * same size. + * \brief Perform generic copy between two tensors partitions (threadwise copy). + * Tensors must have the same size. * * \param src_tensor Source tensor. * \param dst_tensor Destination tensor. @@ -37,5 +42,134 @@ __host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& ds } } +/** + * \brief Perform optimized copy between two tensors partitions (threadwise copy). + * Tensors must have the same size. + * + * \tparam DimAccessOrderTuple Tuple with dimension access order. + * \tparam VectorDim Dimension for vectorized read and write. + * \tparam ScalarPerVector Number of scalar per vectorized read and write. + * \param src_tensor Source tensor. + * \param dst_tensor Destination tensor. 
+ */ +template +__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) +{ + static_assert(is_detected::value); + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor(); + const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor(); + + using SrcShapeType = remove_cvref_t; + constexpr index_t num_dims = SrcShapeType::Size(); + + constexpr auto thread_slice_lengths = + generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number{}); + constexpr auto dim_access_order = generate_sequence_v2( + [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number{}); + + if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer) + { + // Perform a copy between DynamicBuffers + auto transfer = ThreadwiseTensorSliceTransfer_v7< + Tuple, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(InMemoryDataOperationEnum::Set)>, + decltype(thread_slice_lengths), + decltype(dim_access_order), + VectorDim, + ScalarPerVector, + Sequence, + Sequence>{in_grid_desc, + make_tuple(src_tensor.GetMultiIdxOffsets()), + out_grid_desc, + make_tuple(dst_tensor.GetMultiIdxOffsets()), + tensor_operation::element_wise::PassThrough{}}; + + transfer.Run(tie(in_grid_desc), + tie(src_tensor.GetBuffer()), + tie(out_grid_desc), + tie(dst_tensor.GetBuffer())); + } + else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer) + { + // Perform copy from StaticBuffer to DynamicBuffer + const auto src_slice_origin_idxs = + generate_tuple([&](auto) { return I0; }, Number{}); + + auto transfer = + ThreadwiseTensorSliceTransfer_v1r3, + remove_cvref_t, + tensor_operation::element_wise::PassThrough, + decltype(thread_slice_lengths), + decltype(dim_access_order), + VectorDim, + ScalarPerVector, + InMemoryDataOperationEnum::Set, + I1, + true>{out_grid_desc, + dst_tensor.GetMultiIdxOffsets(), + tensor_operation::element_wise::PassThrough{}}; + + transfer.Run(in_grid_desc, + src_slice_origin_idxs, + src_tensor.GetBuffer(), + out_grid_desc, + dst_tensor.GetBuffer()); + } + else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer) + { + // Perform copy from DynamicBuffer to StaticBuffer + const auto src_dst_slice_origin = + generate_tuple([&](auto) { return I0; }, Number{}); + constexpr auto src_vector_tensor_lengths = generate_sequence_v2( + [&](auto I) { + if constexpr(I == VectorDim) + { + return Number{}; + } + else + { + return I1; + } + }, + Number{}); + + auto transfer = + ThreadwiseTensorSliceTransfer_v4r1, + remove_cvref_t, + decltype(thread_slice_lengths), + decltype(dim_access_order), + decltype(src_vector_tensor_lengths), + decltype(dim_access_order)>{ + src_tensor.GetMultiIdxOffsets()}; + + transfer.Run(in_grid_desc, + src_dst_slice_origin, + src_tensor.GetBuffer(), + out_grid_desc, + src_dst_slice_origin, + dst_tensor.GetBuffer()); + } + else + { + // Perform copy between StaticBuffers + copy(src_tensor, dst_tensor); + } +} + } // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index a363641373..57d79c5940 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -10,189 +10,205 @@ namespace ck { namespace wrapper { +namespace detail { +namespace { +/** + * \brief Check if Tuple contains Slice object + * + * \return True if tuple contains Slice object. 
+ */ +template +__host__ __device__ constexpr bool HasSlice(T&&) +{ + return is_detected::value; +} +template +__host__ __device__ constexpr bool HasSlice(Tuple&&) +{ + return (HasSlice(Ts{}) || ...); +} + +/** + * \brief Calculate new shape after slice from parent shape. + * + * \param idxs Tuple of indexes defining slice ranges. + * \param shape Shape which will be sliced. + * \return New tensor shape. + */ +template +__host__ __device__ constexpr auto GetSlicedShape(const Tuple& idxs, + const SlicedShape& shape) +{ + // Pack each value in tuple to remove empty tuples after generation + auto new_shape = generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + if constexpr(!detail::HasSlice(tuple_element_t>{})) + { + // if tuple does not have any slice then we can remove dimension + return Tuple<>{}; + } + else + { + // if tuple then recurrence + return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i))); + } + } + else if constexpr(is_detected>>::value) + { + // calculate new dimension + const auto& dim = size(shape.At(num_i)); + const auto val = idxs.At(num_i).range(dim); + return make_tuple(val); + } + else + { + // remove dimension for just value + return Tuple<>{}; + } + }, + Number::Size()>{}); + // Remove empty tuples (deleted elements) and return + return UnrollNestedTuple<0, 1>(new_shape); +} + +/** + * \brief Generate Freeze for each of nested shape. + * + * \param idx Tuple of start indices for slice. + * \param shape Shape which will be freezed. + * \return Generated freeze transforms. + */ +template +__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape) +{ + const auto unrolled_shape = UnrollNestedTuple(shape); + return generate_tuple( + [&](auto i) { + // dimension offset from idx + const auto dim = unrolled_shape.At(Number{}); + const auto dim_idx = idx % dim; + idx /= dim; + return make_freeze_transform(dim_idx); + }, + Number{}); +} + +/** + * \brief Generate transforms for slice tensor. + * + * \param idx Tuple of start indices for slice. + * \param shape Shape which will be sliced. + * \return Generated transforms. 
+ */ +template +__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple& idx, + const Shape& shape) +{ + // Pack each value in tuple to remove empty tuples after generation + auto transforms = generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + if constexpr(is_detected>>::value) + { + return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i)); + } + else if constexpr(is_detected>>::value) + { + + const auto from = idx.At(num_i).from_; + const auto dim = size(shape); + const auto range = idx.At(num_i).range(dim); + return make_slice_transform(range, from, from + range); + } + else + { + // remove dimension for just value + return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i)); + } + }, + Number::Size()>{}); + // Remove empty tuples (deleted elements) and return + return UnrollNestedTuple(transforms); +} + +template +__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze&) +{ + // There is no output for Freeze transform + return Sequence<>{}; +} + +template +__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice&) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) +{ + return Tuple<>{}; +} + +template +__host__ __device__ constexpr auto GenerateUpperDims(const Tuple& transforms) +{ + constexpr auto num_transforms = Tuple::Size(); + // Deduce Sequence element for specific transform + const auto current_elem = GetSequenceVal(transforms.At(Number<0>{})); + if constexpr(is_same_v>) + { + const auto next_tuple = GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); + return concat_tuple(make_tuple(current_elem), next_tuple); + } + else + { + // Increase i if current_elem is Slice transform + const auto next_tuple = GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); + return concat_tuple(make_tuple(current_elem), next_tuple); + } +} + +template +__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& idx, + const Shape& shape, + const FlattenDescriptor& flatten_desc) +{ + constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); + + const auto transforms = GenerateSliceTransforms(idx, shape); + using TransformsTupleType = decltype(transforms); + + const auto lower_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; + return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); +} +} // namespace +} // namespace detail + /** * \brief Tensor wrapper that performs static and dynamic buffer logic. + * The tensor is based on a descriptor stored in the Layout. Additionally, + * tensor can be sliced or shifted using multi-index offset. * * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR). * \tparam ElementType Element data type. * \tparam Shape Tensor shape (layout component). - * \tparam UnnestedDescriptorType Unnested descriptor (layout component). - * \tparam NumVectors Number of vectors (only for VGPR, SGPR). - * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR). + * \tparam UnrolledDescriptorType Flatten descriptor (layout component). 
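+ *
+ * \note Construction sketch (illustrative; data_ptr is an assumed name):
+ * \code
+ * const auto layout = make_layout(make_tuple(Number<4>{}, Number<64>{}));
+ * auto tensor = make_tensor<MemoryTypeEnum::Global>(data_ptr, layout);
+ * \endcode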
*/ template + typename UnrolledDescriptorType> struct Tensor { - private: - // Check if Tuple contains Slice object - template - __host__ __device__ constexpr static bool IsSlicing(T&&) - { - return is_detected::value; - } - template - __host__ __device__ constexpr static bool IsSlicing(Tuple&&) - { - return (IsSlicing(Ts{}) || ...); - } - - // Calculate new tensor shape after slice - template - __host__ __device__ constexpr auto GetShapeFromSlicedTensor(const Tuple& idx, - const ShapeTmpType& shape) const - { - // Pack each value in tuple to remove empty tuples after generation - auto new_shape = generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - if constexpr(!IsSlicing(tuple_element_t>{})) - { - // if tuple does not have any slice then we can remove dimension - return Tuple<>{}; - } - else - { - // if tuple then recurrence - return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i))); - } - } - else if constexpr(is_detected>>::value) - { - // calculate new dimension - const auto& dim = size(shape.At(num_i)); - const auto val = idx.At(num_i).range(dim); - return make_tuple(val); - } - else - { - // remove dimension for just value - return Tuple<>{}; - } - }, - Number::Size()>{}); - // Remove empty tuples (deleted elements) and return - return UnrollNestedTuple<0, 1>(new_shape); - } - - // Generate Freeze for each of nested shape - template - __host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, - const ShapeTmpType& shape) const - { - const auto unrolled_shape = UnrollNestedTuple(shape); - return generate_tuple( - [&](auto i) { - // dimension offset from idx - const auto dim = unrolled_shape.At(Number{}); - const auto dim_idx = idx % dim; - idx /= dim; - return make_freeze_transform(dim_idx); - }, - Number{}); - } - - template - __host__ __device__ constexpr auto - GetTransformsFromSlicedTensor(const Tuple& idx, const ShapeTmpType& shape) const - { - // Pack each value in tuple to remove empty tuples after generation - auto transforms = generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - return GetTransformsFromSlicedTensor(idx.At(num_i), shape.At(num_i)); - } - else if constexpr(is_detected>>::value) - { - - const auto from = idx.At(num_i).from_; - const auto dim = shape.At(num_i); - const auto range = idx.At(num_i).range(dim); - return make_slice_transform(range, from, from + range); - } - else - { - // remove dimension for just value - return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i)); - } - }, - Number::Size()>{}); - // Remove empty tuples (deleted elements) and return - return UnrollNestedTuple(transforms); - } - - // There is no output for Freeze transform - template - __host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze&) const - { - return Sequence<>{}; - } - - template - __host__ __device__ constexpr auto - GetSequenceVal(const ck::Slice&) const - { - return Sequence{}; - } - - template - __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&) const - { - return Tuple<>{}; - } - - template - __host__ __device__ constexpr auto - GenerateUpperDims(const Tuple& transforms) const - { - constexpr auto num_transforms = Tuple::Size(); - // Deduce Sequence element for specific transform - const auto currect_elem = GetSequenceVal(transforms.At(Number<0>{})); - if constexpr(is_same_v>) - { - const auto next_tuple = GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); - return 
concat_tuple(make_tuple(currect_elem), next_tuple); - } - else - { - // Increase i if current_elem is Slice transform - const auto next_tuple = - GenerateUpperDims(TupleSlice<1, num_transforms>(transforms)); - return concat_tuple(make_tuple(currect_elem), next_tuple); - } - } - - template - __host__ __device__ constexpr auto - GetDescriptorFromSlicedTensor(const Tuple& idx, - const ShapeTmpType& shape, - const FlattenDescriptor& flatten_desc) const - { - constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); - - const auto transforms = GetTransformsFromSlicedTensor(idx, shape); - using TransformsTupleType = decltype(transforms); - - const auto lower_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; - return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); - } - public: - using ElementSpaceSize = decltype(Layout{ - Shape{}, UnnestedDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer + using ElementSpaceSize = decltype(Layout{ + Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer using TensorElementType = ElementType; // DataType static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace; @@ -200,134 +216,207 @@ struct Tensor BufferAddressSpace == MemoryTypeEnum ::Vgpr); __host__ __device__ Tensor() = delete; - __host__ __device__ Tensor(ElementType* pointer, - const Layout& layout) + __host__ __device__ constexpr Tensor(ElementType* pointer, + const Layout& layout) : layout_(layout), - buffer_(make_dynamic_buffer(pointer, layout.GetElementSpaceSize())) + buffer_(make_dynamic_buffer(pointer, layout.GetElementSpaceSize())), + multi_idx_offset_(make_zero_multi_index()), + base_offset_(0) { + static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register."); } - __host__ __device__ Tensor(const Layout& layout) - : layout_(layout) + __host__ __device__ constexpr Tensor(const Layout& layout) + : layout_(layout), + multi_idx_offset_(make_zero_multi_index()), + base_offset_(0) { static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register."); } - __host__ __device__ constexpr const Layout& GetLayout() const + __host__ __device__ constexpr const Layout& GetLayout() const { return layout_; } - // Getter for new sliced tensor - template {}), bool> = false> - __host__ __device__ auto operator[](const Tuple& idx) const + /** + * \brief Get the new sliced tensor. + * + * \param idx Tuple of indices: slice(from,to) or scalar. + * \return Sliced tensor. 
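+ *
+ * \code
+ * // Illustrative: rows 0..3 of column 1; the scalar index is frozen,
+ * // so the result is a 1d tensor of length 4.
+ * auto column = tensor(make_tuple(slice(0, 4), 1));
+ * \endcode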
+ */ + template {}), bool> = false> + __host__ __device__ auto operator[](const Tuple& idx) { static_assert(IsDynamicBuffer, "Register slice is not supported"); const auto& shape = layout_.GetShape(); - auto new_shape = GetShapeFromSlicedTensor(idx, shape); + auto new_shape = detail::GetSlicedShape(idx, shape); - const auto& flatten_desc = layout_.GetUnnestedDescriptor(); - auto new_desc = GetDescriptorFromSlicedTensor(idx, shape, flatten_desc); + const auto& flatten_desc = layout_.GetUnrolledDescriptor(); + auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc); const auto new_layout = Layout(new_shape, new_desc); + // Update embed offset + base_offset_ -= new_layout(make_tuple(Number<0>{})); return make_tensor(buffer_.p_data_, new_layout); } - template {}), bool> = false> - __host__ __device__ auto operator()(const Tuple& idx) const + template {}), bool> = false> + __host__ __device__ auto operator()(const Tuple& idx) { return this->operator[](idx); } - template {}), bool> = false> - __host__ __device__ auto operator()(Idxs... idxs) const + template {}), bool> = false> + __host__ __device__ auto operator()(Idxs... idxs) { return this->operator[](make_tuple(idxs...)); } - // Getter for the const value - template {}), bool> = false> + /** + * \brief Getter of the tensor's const value reference. + * + * \param idx Tuple of indices. + * \return Requested value. + */ + template {}), bool> = false> __host__ __device__ const ElementType& operator[](const Tuple& idx) const { if constexpr(IsDynamicBuffer) { - const index_t offset = layout_(idx); + const index_t offset = layout_(idx) + base_offset_; return buffer_[offset]; } else { - constexpr index_t offset = Layout{ + constexpr index_t index_offset = Layout{ Shape{}, - UnnestedDescriptorType{}}.template operator()>(); - return buffer_[Number{}]; + UnrolledDescriptorType{}}.template operator()>(); + // Calculate and apply base offset in compile-time + constexpr index_t base_offset = Layout{ + Shape{}, + UnrolledDescriptorType{}}.template operator()>(); + return buffer_[Number{}]; } } - template {}), bool> = false> + template {}), bool> = false> __host__ __device__ const ElementType& operator()(const Tuple& idx) const { return this->operator[](idx); } - template {}), bool> = false> + template {}), bool> = false> __host__ __device__ const ElementType& operator()(Idxs... idxs) const { return this->operator[](make_tuple(idxs...)); } - // Getter for the value reference - template {}), bool> = false> + /** + * \brief Getter of tensor value reference. + * + * \param idx Tuple of indices. + * \return Requested value. 
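+ *
+ * \note Illustrative: tensor(make_tuple(0, 1)) = value; writes through
+ * the returned reference.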
+ */
+    template {}), bool> = false>
    __host__ __device__ ElementType& operator[](const Tuple& idx)
    {
        if constexpr(IsDynamicBuffer)
        {
-            const index_t offset = layout_(idx);
+            const index_t offset = layout_(idx) + base_offset_;
            return buffer_(offset);
        }
        else
        {
-            constexpr index_t offset = Layout{
+            constexpr index_t index_offset = Layout{
                Shape{},
-                UnnestedDescriptorType{}}.template operator()>();
-            return buffer_(Number{});
+                UnrolledDescriptorType{}}.template operator()>();
+            // Apply embed offset (calculated at compile time)
+            constexpr index_t base_offset = Layout{
+                Shape{},
+                UnrolledDescriptorType{}}.template operator()>();
+            return buffer_(Number{});
        }
    }

-    template {}), bool> = false>
+    template {}), bool> = false>
    __host__ __device__ ElementType& operator()(const Tuple& idx)
    {
        return this->operator[](idx);
    }

-    template {}), bool> = false>
+    template {}), bool> = false>
    __host__ __device__ ElementType& operator()(Idxs... idxs)
    {
        return this->operator[](make_tuple(idxs...));
    }

-    __host__ __device__ constexpr auto GetDefaultDescriptor()
+    /**
+     * \brief Get descriptor with all nested dimensions merged.
+     *
+     * \return Merged nests descriptor.
+     */
+    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
    {
-        return layout_.GetDefaultDescriptor();
+        return layout_.GetMergedNestingDescriptor();
    }

+    /**
+     * \brief Get pointer to the data.
+     *
+     * \return Pointer.
+     */
    __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }

+    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
+    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
+
+    /**
+     * \brief Get multi index offset to the data.
+     *
+     * \return Multi index offset.
+     */
+    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }
+
+    /**
+     * \brief Apply multi index offset on the tensor.
+     *
+     * \param multi_idx_offset Multi index offset.
+     */
+    template
+    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
+    {
+        multi_idx_offset_ = multi_idx_offset;
+        base_offset_ += layout_(multi_idx_offset);
+    }
+
    private:
    using DynamicBufferType = DynamicBuffer;
-    using StaticBufferType =
-        StaticBufferTupleOfVector;
+    using StaticBufferType = StaticBuffer;
    // If register use static buffer, else use dynamic buffer
    using Buffer = std::conditional_t;

-    const Layout layout_;
+    const Layout layout_;
    Buffer buffer_;
+    // We use multi_idx_offset_ to enable the creation of a descriptor at
+    // compile time for partitions or tiles if the tile shape and thread
+    // layout are known at compile time (we can then use the same descriptor
+    // for each thread). Additionally, the copy between the static and
+    // dynamic buffer requires a descriptor known at compile time, so we can
+    // shift data using such a multi_idx_offset_.
+    MultiIndex multi_idx_offset_;
+    // The base offset and the multi index offset correspond to exactly the
+    // same element in the tensor (and in physical memory). The multi index
+    // offset is a multi-dimensional index, while the base offset is computed
+    // through the tensor descriptor (thus through all of its transforms) and
+    // is linear (1D). We store base_offset_ to avoid repeated recalculation.
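+    // Illustrative: after SetMultiIdxOffset(make_multi_index(2, 3)) on a 2d
+    // tensor, multi_idx_offset_ is (2, 3) and base_offset_ has grown by
+    // layout_((2, 3)), so element access adds the precomputed scalar offset
+    // instead of re-running the descriptor transforms for the shift.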
+ index_t base_offset_; }; } // namespace wrapper diff --git a/include/ck/wrapper/utils/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp index f4ba0a969f..d04bd5078b 100644 --- a/include/ck/wrapper/utils/layout_utils.hpp +++ b/include/ck/wrapper/utils/layout_utils.hpp @@ -22,14 +22,19 @@ namespace wrapper { // Disable from doxygen docs generation /// @cond // forward declaration -template +template struct Layout; template using is_tuple = decltype(std::declval().IsTuple()); namespace { -// Generate packed (column-major) strides if not passed +/** + * \brief Generate packed (column-major) strides if not passed + * + * \param shape Tensor shape. + * \return Generated column-major strides. + */ template __host__ __device__ constexpr static auto GenerateColumnMajorPackedStrides(const Tuple& shape) @@ -50,9 +55,16 @@ GenerateColumnMajorPackedStrides(const Tuple& shape) Number{}); } +/** + * \brief Create naive tensor descriptor from nested shape. + * + * \param shape Tensor shape. + * \param strides Tensor strides. + * \return Unrolled descriptor + */ template -__host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shape, - const LayoutStrides& strides) +__host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& shape, + const LayoutStrides& strides) { const auto unrolled_shape = UnrollNestedTuple(shape); if constexpr(is_same_v>) @@ -86,8 +98,8 @@ __host__ __device__ constexpr auto MakeFlattenDescriptor(const LayoutShape& shap template __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides) { - using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Strides{})); - return Layout(shape, MakeFlattenDescriptor(shape, strides)); + using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{})); + return Layout(shape, MakeUnrolledDescriptor(shape, strides)); } /** @@ -100,15 +112,19 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides template __host__ __device__ constexpr auto make_layout(const Shape& shape) { - using UnnestedDescriptorType = decltype(MakeFlattenDescriptor(Shape{}, Tuple<>{})); - return Layout(shape, MakeFlattenDescriptor(shape, Tuple<>{})); + using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{})); + return Layout(shape, MakeUnrolledDescriptor(shape, Tuple<>{})); } // Layout helpers // get -// Get dim (could be returned from get with empty Idxs) + /** * \private + * \brief Get dim. + * + * \param dim Dimension. + * \return Returned the same dimension. */ template __host__ __device__ T constexpr get(const T& dim) @@ -178,7 +194,7 @@ __host__ __device__ constexpr auto get(const Layout& layout) }, Number{}); - const auto& flatten_desc = layout.GetUnnestedDescriptor(); + const auto& flatten_desc = layout.GetUnrolledDescriptor(); auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); return Layout(new_shape, new_desc); } @@ -197,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem) } // size -// Get dim size (could be returned from get function) /** * \private + * \brief Get size. + * + * \param dim Size. + * \return Returned the same size. */ template __host__ __device__ T constexpr size(const T& dim) @@ -214,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim) * \param layout Layout to get Shape of. * \return Requsted length. 
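* \note Illustrative: for a layout with shape (4, 64), size<1>(layout)
* returns 64.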
*/ -template -__host__ __device__ constexpr auto size(const Layout& layout) +template +__host__ __device__ constexpr auto size(const Layout& layout) { return layout.template GetLength(); } @@ -240,8 +259,8 @@ __host__ __device__ constexpr auto size(const Tuple& shape) * \param layout Layout to calculate shape size. * \return Requsted size. */ -template -__host__ __device__ constexpr auto size(const Layout& layout) +template +__host__ __device__ constexpr auto size(const Layout& layout) { return layout.GetLengths(); } @@ -280,9 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem) * \param layout Layout to calculate rank. * \return Requsted rank. */ -template +template __host__ __device__ constexpr auto -rank([[maybe_unused]] const Layout& layout) +rank([[maybe_unused]] const Layout& layout) { return Shape::Size(); } @@ -302,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple& t /** * \private + * \brief Rank for scalar + * + * \param dim Dimension scalar. + * \return Returned 1. */ template -__host__ __device__ constexpr index_t rank(const Number&) +__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number& dim) { return 1; } /** * \private + * \brief Rank for scalar + * + * \param dim Dimension scalar. + * \return Returned 1. */ -__host__ __device__ constexpr index_t rank(const index_t&) { return 1; } +__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim) { return 1; } /** * \brief Hierarchical rank. @@ -334,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem) * \param layout Layout to calculate depth. * \return Requsted depth. */ -template -__host__ __device__ constexpr auto depth(const Layout& layout) +template +__host__ __device__ constexpr auto depth(const Layout& layout) { const auto& shape = layout.GetShape(); return TupleDepth(shape); @@ -355,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple& tuple) /** * \private + * \brief Depth for scalar + * + * \param dim Scalar. + * \return Returned 0. */ template -__host__ __device__ constexpr index_t depth(const Number&) +__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number& dim) { return 0; } /** * \private + * \brief Depth for scalar + * + * \param dim Scalar. + * \return Returned 0. */ -__host__ __device__ constexpr index_t depth(const index_t&) { return 0; } +__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim) { return 0; } /** * \brief Hierarchical depth. diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index a0634f6b38..6aae5a92fe 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -6,12 +6,22 @@ #include "tensor_utils.hpp" #include "layout_utils.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" + namespace ck { namespace wrapper { namespace { -// Calculate shape for partition based on number of threads per each dim and -// previous shape + +/** + * \brief Calculate shape for partition based on number of threads per each dim and + * previous shape + * + * \param shape Base tensor shape. + * \param thread_lengths Tuple of thread lengths. + * \return Partition shape. 
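+ *
+ * \note Worked example (illustrative): for shape (4, 64) and
+ * thread_lengths (1, 32), each dimension is divided by the matching
+ * thread count, so the resulting partition shape is (4, 2).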
+ */ template __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple& shape, const Tuple& thread_lengths) @@ -20,265 +30,165 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - return CalculateLocalPartitionShape(shape.At(num_i), thread_lengths.At(num_i)); - } - else - { - const auto slice_len = shape.At(num_i) / thread_lengths.At(num_i); - return slice_len; - } - }, - Number::Size()>{}); -} - -// Calculate shape for partition based on number of threads per each dim, -// previous strides and steps -template -__host__ __device__ constexpr auto -CalculateLocalPartitionDescriptor(const Tuple& shape, - const Tuple& thread_lengths, - const Tuple& steps, - const FlattenDescType& flatten_desc) -{ - - static_assert(Tuple::Size() == Tuple::Size(), "Wrong thread_lengths shape."); - const auto unrolled_thread_lengths = UnrollNestedTuple(thread_lengths); - const auto unrolled_shape = UnrollNestedTuple(shape); - constexpr auto dims = decltype(unrolled_thread_lengths)::Size(); - - using UnrolledStepsType = decltype(UnrollNestedTuple(steps)); - - using I1 = Number<1>; - - const auto transforms = generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_same_v, Tuple<>>) - { - // By default raked partition - const auto partition_stride = unrolled_thread_lengths.At(num_i); - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(partition_stride)); - } - else if constexpr(!is_same_v, index_t>) - { - // Compiletime partition - if constexpr(is_same_v, I1>) - { - // raked - const auto partition_stride = unrolled_thread_lengths.At(num_i); - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(partition_stride)); - } - else - { - // packed - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(I1{})); - } - } - else - { - // Runtime partition - if(steps.At(num_i) == 1) - { - // raked - const auto partition_stride = unrolled_thread_lengths.At(num_i); - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(partition_stride)); - } - else - { - // packed - return make_embed_transform(make_tuple(unrolled_shape.At(num_i)), - make_tuple(I1{})); - } - } - }, - Number{}); - - const auto lower_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - const auto upper_dims = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); -} - -template -__host__ __device__ constexpr auto CalculateLayoutOffsetIdxImpl(const Tuple& thread_lengths, - const Tuple& steps, - index_t& thread_id) -{ - return generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - if constexpr(is_same_v, Tuple<>>) - { - return CalculateLayoutOffsetIdxImpl( - thread_lengths.At(num_i), Tuple<>{}, thread_id); - } - else - { - return CalculateLayoutOffsetIdxImpl( - thread_lengths.At(num_i), steps.At(num_i), thread_id); - } - } - else - { - // Update thread_id after each dim - const auto dim_thread_id = thread_id % thread_lengths.At(num_i); - thread_id /= thread_lengths.At(num_i); - if constexpr(is_same_v, Tuple<>>) - { - return dim_thread_id; - } - else - { - // Apply step - return steps.At(num_i) * dim_thread_id; - } - } + const auto slice_len = size(shape) / thread_lengths.At(num_i); + return slice_len; }, 
Number::Size()>{}); } -// Convert integer thread_idx to tuple index with steps applied -template -__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple& thread_lengths, - const Tuple& steps, - const index_t thread_id) +/** + * \brief Calculate total number of blocks. + * + * \param shape Base tensor shape. + * \param tile_shape Tile shape. + * \return Tuple with blocks number. + */ +template +__host__ __device__ constexpr auto CalculateGridSize(const Tuple& shape, + const Tuple& tile_shape) { - // Create tmp thread_id copy for CalculateLayoutOffsetIdxImpl updates - index_t thread_id_copy = thread_id; - return CalculateLayoutOffsetIdxImpl(thread_lengths, steps, thread_id_copy); + static_assert(Tuple::Size() == Tuple::Size(), "Wrong thread_lengths shape."); + return generate_tuple([&](auto i) { return size(shape) / size(tile_shape); }, + Number::Size()>{}); } -// Apply steps to index represented as tuple -template -__host__ __device__ constexpr auto CalculateLayoutOffsetIdx(const Tuple& steps, - const Tuple& block_idxs) +/** + * \brief Calculate scaled offset for new partition/tile. + * + * \param thread_idxs Thread 1d id. + * \param partition_lengths_seq Sequence of partition shape. + * \param old_offset_idxs Multi index offset from base tensor to shift values. + * \return Partition shape. + */ +template +__host__ __device__ constexpr auto +CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, + const PartitionLengthsSeq& partition_lengths_seq, + const OldOffsetIdxs& old_offset_idxs) { - return generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - if constexpr(is_same_v, Tuple<>>) - { - return CalculateLayoutOffsetIdx(Tuple<>{}, block_idxs.At(num_i)); - } - else - { - return CalculateLayoutOffsetIdx(steps.At(num_i), block_idxs.At(num_i)); - } - } - else - { - if constexpr(is_same_v, Tuple<>>) - { - return block_idxs.At(num_i); - } - else - { - // apply step - return steps.At(num_i) * block_idxs.At(num_i); - } - } - }, - Number::Size()>{}); + return thread_idxs * partition_lengths_seq + old_offset_idxs; } -// User passes only shape per block to the make_local_tile function. This function calculates -// block layout based on the shape. -template -__host__ __device__ constexpr auto CalculateBlockLengths(const Tuple& shape, - const Tuple& tile_shape) -{ - return generate_tuple( - [&](auto i) { - constexpr auto num_i = Number{}; - if constexpr(is_detected>>::value) - { - // if tuple then recurrence - return CalculateBlockLengths(shape.At(num_i), tile_shape.At(num_i)); - } - else - { - return shape.At(num_i) / tile_shape.At(num_i); - } - }, - Number::Size()>{}); -} } // namespace /** - * \brief Create local partition for thread. + * \brief Create local partition for thread (At now only packed partition + * is supported). * * \param tensor Tensor for partition. - * \param thread_lengths Layout of threads. + * \param thread_lengths Layout of threads (could not be nested). * \param thread_id Thread index represented as integer. - * \param steps Thread step (default=1, raked partition) * \return Partition tensor. 
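+ *
+ * \code
+ * // Illustrative usage (mirrors test_copy.cpp in this patch):
+ * const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{});
+ * auto partition =
+ *     ck::wrapper::make_local_partition(tile, thread_layout, threadIdx.x);
+ * \endcode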
*/
-template >
-__host__ __device__ constexpr auto make_local_partition(const TensorType& tensor,
-                                                        const ThreadLengthsTuple& thread_lengths,
-                                                        const index_t thread_id,
-                                                        const StepsTuple steps = StepsTuple{})
+template
+__host__ __device__ constexpr auto
+make_local_partition(TensorType& tensor,
+                     [[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
+                     const index_t thread_id)
 {
-    // Create shape, strides and layout for new partition tensor
-    const auto partition_shape = CalculateLocalPartitionShape(shape(tensor), thread_lengths);
-    // Create new descriptor and layout
-    const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor();
-    auto partition_desc =
-        CalculateLocalPartitionDescriptor(shape(tensor), thread_lengths, steps, flatten_desc);
-    const auto partition_layout = Layout(
-        partition_shape, partition_desc);
-    // Calculate offset for new partition tensor
-    const auto offset_idx = CalculateLayoutOffsetIdx(thread_lengths, steps, thread_id);
-    const auto partition_offset = layout(tensor)(offset_idx);
-    return make_tensor(tensor.GetPointer() + partition_offset,
-                       partition_layout);
+    static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
+    // Calculate new partition shape
+    const auto& tensor_shape = shape(tensor);
+    constexpr auto partition_shape =
+        CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
+    // Create Thread Cluster Descriptor
+    constexpr auto partition_lengths_seq = generate_sequence_v2(
+        [&](auto I) { return size(partition_shape); }, Number{});
+    constexpr auto thread_lengths_seq =
+        generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); },
+                             Number{});
+    constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
+    // Calculate thread idxs and offsets
+    const auto thread_idxs =
+        thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
+    const auto offset_multi_idxs =
+        CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
+    // Create new layout and tensor
+    auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
+    const auto partition_layout =
+        Layout, decltype(flatten_desc)>(
+            partition_shape, flatten_desc);
+    auto partition_tensor =
+        make_tensor(tensor.GetPointer(), partition_layout);
+    // Apply offsets
+    partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
+    return partition_tensor;
 }

 /**
- * \brief Create local tile for thread block.
+ * \brief Create local tile for thread block (for now, only packed tiles
+ * are supported).
+ *
+ * \note For the time being, use a 2d tile_shape to get the best
+ * performance.
+ *
 *
 * \param tensor Tensor for partition.
 * \param tile_shape Shapes of requested tile.
- * \param block_idx Block index represented as tuple.
- * \param steps Block step (default=1, raked partition)
+ * \param block_id Block index represented as integer.
+ *
+ * \return Tile tensor.
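+ *
+ * \code
+ * // Illustrative usage (mirrors test_copy.cpp in this patch):
+ * const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
+ * const auto tile = ck::wrapper::make_local_tile(
+ *     tensor, tile_shape, static_cast<ck::index_t>(blockIdx.x));
+ * \endcode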
*/ -template > -__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, - const BlockShapeTuple& tile_shape, - const BlockIdxTuple& block_idx, - const StepsTuple steps = StepsTuple{}) +template +__host__ __device__ constexpr auto +make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id) { - // Create block lengths, strides and layout for new tile tensor - const auto block_lengths = CalculateBlockLengths(shape(tensor), tile_shape); - // Create new descriptor and layout - const auto& flatten_desc = layout(tensor).GetUnnestedDescriptor(); - auto tile_desc = - CalculateLocalPartitionDescriptor(tile_shape, block_lengths, steps, flatten_desc); - const auto tile_layout = Layout, decltype(tile_desc)>( - tile_shape, tile_desc); - // Calculate offset for new partition tensor - const auto offset_idx = CalculateLayoutOffsetIdx(steps, block_idx); - const auto tile_offset = layout(tensor)(offset_idx); - return make_tensor(tensor.GetPointer() + tile_offset, - tile_layout); + static_assert(!IsNestedTuple(BlockShapeTuple{})); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor(); + + if constexpr(BlockShapeTuple::Size() == I2) + { + // Optimized version for 2d tile shape [MxK] + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt>(aligned_desc); + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id)); + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape)); + const index_t k_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape)); + const auto offset_multi_idxs = + make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid); + // Create new layout and tensor + const auto tile_layout = + Layout, decltype(aligned_desc)>(tile_shape, + aligned_desc); + auto tile_tensor = + make_tensor(tensor.GetPointer(), tile_layout); + // Apply offsets + tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + return tile_tensor; + } + else + { + // Calculate offsets + // Sequence with data to process per block + constexpr auto tile_shape_seq = + generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); }, + Number{}); + // Tuple with number of blocks + const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape); + constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths); + const auto block_idxs = + block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id)); + const auto offset_multi_idxs = + CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets()); + // Create new layout and tensor + const auto tile_layout = + Layout, decltype(aligned_desc)>(tile_shape, + aligned_desc); + auto tile_tensor = + make_tensor(tensor.GetPointer(), tile_layout); + // Apply offsets + tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + return tile_tensor; + } } } // namespace wrapper diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp index 1e932e62e1..7ec080760a 100644 --- a/include/ck/wrapper/utils/tensor_utils.hpp +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. 
All rights reserved. #pragma once @@ -10,6 +10,7 @@ #include "ck/utility/tuple_helper.hpp" #include "ck/utility/dynamic_buffer.hpp" #include "ck/utility/amd_address_space.hpp" +#include "ck/utility/multi_index.hpp" namespace ck { namespace wrapper { @@ -27,16 +28,12 @@ using MemoryTypeEnum = AddressSpaceEnum; // Disable from doxygen docs generation /// @cond // forward declarations -template +template struct Layout; template - + typename UnrolledDescriptorType> struct Tensor; template @@ -45,13 +42,22 @@ struct Slice __host__ __device__ constexpr Slice() : from_(), to_() {} __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {} + /** + * \brief Calculate slice range. + * + * \param dim Dimension size. + * \return Slice range. + */ template __host__ __device__ constexpr auto range(const T& dim) const { if constexpr(is_same_v || is_same_v || is_same_v) { - assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range"); + if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_))) + { + throw std::runtime_error("Invalid range"); + } if(to_ < 0) { return dim - from_ + to_ + 1; @@ -101,40 +107,27 @@ using is_tuple = decltype(std::declval().IsTuple()); template + typename UnrolledDescriptorType> constexpr auto make_tensor(ElementType* pointer, - const Layout& layout) + const Layout& layout) { - return Tensor(pointer, layout); + return Tensor(pointer, layout); } /** * \brief Make SGPR or VGPR tensor function. * * \tparam MemoryType Type of memory. - * \tparam NumVectors Number of vectors. - * \tparam ScalarPerVector Scalars per vector. * \tparam ElementType Memory data type. * \return Constructed tensor. */ template -constexpr auto make_register_tensor() + typename ElementType, + typename Shape, + typename UnrolledDescriptorType> +constexpr auto make_register_tensor(const Layout& layout) { - const auto layout = make_layout(make_tuple(Number{}), make_tuple(Number<1>{})); - return Tensor>, - std::remove_const_t>, - NumVectors, - ScalarPerVector>(layout); + return Tensor(layout); } /** @@ -146,15 +139,9 @@ constexpr auto make_register_tensor() template -__host__ __device__ constexpr const auto& layout(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr const auto& +layout(const Tensor& tensor) { return tensor.GetLayout(); } @@ -170,15 +157,9 @@ template -__host__ __device__ constexpr auto size(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr auto +size(const Tensor& tensor) { return size(tensor.GetLayout()); } @@ -194,15 +175,9 @@ template -__host__ __device__ constexpr auto rank(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr auto +rank(const Tensor& tensor) { return rank(tensor.GetLayout()); } @@ -218,15 +193,9 @@ template -__host__ __device__ constexpr auto depth(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr auto +depth(const Tensor& tensor) { return depth(tensor.GetLayout()); } @@ -240,15 +209,9 @@ __host__ __device__ constexpr auto depth(const Tensor -__host__ __device__ constexpr const auto& shape(const Tensor& tensor) + typename UnrolledDescriptorType> +__host__ __device__ constexpr const auto& +shape(const Tensor& tensor) { return shape(tensor.GetLayout()); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp index 56b0ce7914..750d4d14f8 100644 --- 
a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,6 +10,7 @@ #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/numeric.hpp" namespace ck { namespace tensor_operation { diff --git a/test/wrapper/test_copy.cpp b/test/wrapper/test_copy.cpp index 5cf09a54be..e7fa3c539b 100644 --- a/test/wrapper/test_copy.cpp +++ b/test/wrapper/test_copy.cpp @@ -21,49 +21,59 @@ template + bool UseOptimizedCopy> __global__ void TestCopyDevice(const InputTensor input_tensor, OutputTensor output_tensor, const BlockShape tile_shape, - const ThreadLayoutShape thread_layout, - const LocalTileSteps block_steps, - const LocalPartitionSteps thread_steps) + const ThreadLayoutShape thread_layout) { __shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)]; - auto tensor_lds = ck::wrapper::make_tensor( + const auto tensor_lds = ck::wrapper::make_tensor( p_shared, ck::wrapper::make_layout(tile_shape)); - const auto block_idxs = ck::make_tuple(ck::make_tuple(0, 0), blockIdx.x); + const auto block_idx = static_cast(blockIdx.x); // Get local tiles for global memory - const auto input_local_tile = - ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs, block_steps); + const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); const auto output_local_tile = - ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs, block_steps); + ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); // Get partition per thread - const auto input_local_partition = ck::wrapper::make_local_partition( - input_local_tile, thread_layout, threadIdx.x, thread_steps); + const auto input_local_partition = + ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x); auto lds_local_partition = - ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x, thread_steps); - auto output_local_partition = ck::wrapper::make_local_partition( - output_local_tile, thread_layout, threadIdx.x, thread_steps); + ck::wrapper::make_local_partition(tensor_lds, thread_layout, threadIdx.x); + auto output_local_partition = + ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x); // Allocate VGPR - constexpr ck::index_t scalar_per_vector = 1; - constexpr ck::index_t vgpr_size = ck::wrapper::size(lds_local_partition); - auto tensor_vgpr = ck::wrapper::make_register_tensor(); + auto tensor_vgpr = + ck::wrapper::make_register_tensor( + layout(lds_local_partition)); // Perform copy - ck::wrapper::copy(input_local_partition, lds_local_partition); - ck::wrapper::copy(lds_local_partition, tensor_vgpr); - ck::wrapper::copy(tensor_vgpr, output_local_partition); + if constexpr(UseOptimizedCopy) + { + using DimAccessOrder = ck::Tuple, ck::Number<0>>; + constexpr ck::index_t vector_dim = 0; + constexpr ck::index_t scalar_per_vector = 2; + ck::wrapper::copy(input_local_partition, + lds_local_partition); + // TODO: Enable optimized copy for static buffers + ck::wrapper::copy(lds_local_partition, + tensor_vgpr); + ck::wrapper::copy(tensor_vgpr, + 
output_local_partition); + } + else + { + ck::wrapper::copy(input_local_partition, lds_local_partition); + ck::wrapper::copy(lds_local_partition, tensor_vgpr); + ck::wrapper::copy(tensor_vgpr, output_local_partition); + } } +template void PerformCopyGlobalToGlobalViaLDS() { const auto shape = @@ -89,15 +99,8 @@ void PerformCopyGlobalToGlobalViaLDS() auto output_tensor_global = ck::wrapper::make_tensor( static_cast(out_buf.GetDeviceBuffer()), layout); - const auto thread_layout = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<32>{}); - const auto tile_shape = - ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<64>{}); - - const auto thread_steps = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<2>{}); - const auto block_steps = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}), ck::Number<64>{}); + const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}); + const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}); const ck::index_t grid_size = ck::math::integer_divide_ceil( ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape)); @@ -106,8 +109,7 @@ void PerformCopyGlobalToGlobalViaLDS() decltype(output_tensor_global), decltype(tile_shape), decltype(thread_layout), - decltype(block_steps), - decltype(thread_steps)>; + UseOptimizedCopy>; launch_and_time_kernel(StreamConfig{}, kernel, dim3(grid_size), @@ -116,9 +118,7 @@ void PerformCopyGlobalToGlobalViaLDS() input_tensor_global, output_tensor_global, tile_shape, - thread_layout, - block_steps, - thread_steps); + thread_layout); // Verify results std::vector output_data(ck::wrapper::size(shape)); @@ -126,4 +126,5 @@ void PerformCopyGlobalToGlobalViaLDS() EXPECT_TRUE(ck::utils::check_err(output_data, input_data)); } -TEST(TestCopy, CopyGlobalToGlobalViaLDS) { PerformCopyGlobalToGlobalViaLDS(); } +TEST(TestCopyGlobalToGlobalViaLDS, GenericCopy) { PerformCopyGlobalToGlobalViaLDS(); } +TEST(TestCopyGlobalToGlobalViaLDS, OptimizedCopy) { PerformCopyGlobalToGlobalViaLDS(); } diff --git a/test/wrapper/test_partition.cpp b/test/wrapper/test_partition.cpp index df56b879f6..cacbfe9d88 100644 --- a/test/wrapper/test_partition.cpp +++ b/test/wrapper/test_partition.cpp @@ -29,42 +29,29 @@ TEST(TestPartition, LocalPartition) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto thread_steps = - ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<1>{}), ck::Number<1>{}); - const auto thread_layout = - ck::make_tuple(ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}), ck::Number<1>{}); - - for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) - { - const auto raked_partition = - ck::wrapper::make_local_partition(tensor, thread_layout, thread_id); - - const auto expected_partition_size = - ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); - EXPECT_EQ(ck::wrapper::size(raked_partition), expected_partition_size); - EXPECT_EQ(raked_partition(0), thread_id); - } + const auto thread_steps = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); + const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) { const auto packed_partition = - ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_steps); + ck::wrapper::make_local_partition(tensor, thread_layout, thread_id); const auto expected_partition_size = 
ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); - const auto expected_partition_first_val = thread_id * ck::wrapper::size<0, 0>(thread_steps); + const auto expected_partition_first_val = thread_id * ck::wrapper::size<0>(thread_steps); + const auto expected_partition_second_val = expected_partition_first_val + 1; EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size); EXPECT_EQ(packed_partition(0), expected_partition_first_val); + EXPECT_EQ(packed_partition(1), expected_partition_second_val); } } TEST(TestPartition, LocalTile) { - const auto shape = - ck::make_tuple(ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}), ck::Number<4>{}); - const auto strides = - ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}), ck::Number<64>{}); - const auto layout = ck::wrapper::make_layout(shape, strides); + const auto shape = ck::make_tuple(ck::Number<16>{}, ck::Number<4>{}, ck::Number<4>{}); + const auto strides = ck::make_tuple(ck::Number<1>{}, ck::Number<16>{}, ck::Number<64>{}); + const auto layout = ck::wrapper::make_layout(shape, strides); std::vector data(ck::wrapper::size(layout)); std::iota(data.begin(), data.end(), 0); @@ -72,48 +59,34 @@ TEST(TestPartition, LocalTile) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto block_steps = - ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); - const auto block_shape = - ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); - const auto block_layout = - ck::make_tuple(ck::make_tuple(ck::Number<4>{}, ck::Number<2>{}), ck::Number<2>{}); + const auto block_shape = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}); + const auto num_blocks = + ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape), + ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape), + ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape)); + std::vector block_idxs(ck::wrapper::size(num_blocks)); + std::iota(block_idxs.begin(), block_idxs.end(), 0); - std::vector, ck::index_t>> block_idxs; - for(ck::index_t x = 0; x < ck::wrapper::size<0, 0>(block_layout); x++) + for(auto block_idx : block_idxs) { - for(ck::index_t y = 0; y < ck::wrapper::size<0, 1>(block_layout); y++) - { - for(ck::index_t z = 0; z < ck::wrapper::size<1>(block_layout); z++) - { - block_idxs.emplace_back(ck::make_tuple(x, y), z); - } - } - } - - for(const auto& block_idx : block_idxs) - { - const auto raked_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx); + const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx); const auto expected_tile_size = ck::wrapper::size(block_shape); - EXPECT_EQ(ck::wrapper::size(raked_tile), expected_tile_size); - EXPECT_EQ(raked_tile(0), layout(block_idx)); - } + auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) * + ck::wrapper::size<2>(block_shape) * + ck::wrapper::size<2>(strides); + block_idx /= ck::wrapper::size<2>(num_blocks); + expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) * + ck::wrapper::size<1>(block_shape) * + ck::wrapper::size<1>(strides); + block_idx /= ck::wrapper::size<1>(num_blocks); + expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) * + ck::wrapper::size<0>(block_shape) * + ck::wrapper::size<0>(strides); - for(const auto& block_idx : block_idxs) - { - const auto packed_tile = - ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_steps); - - 
const auto expected_tile_size = ck::wrapper::size(block_shape); - const auto expected_tile_first_val = - ck::wrapper::size<0, 0>(block_idx) * ck::wrapper::size<0, 0>(block_shape) * - ck::wrapper::size<0, 0>(strides) + - ck::wrapper::size<0, 1>(block_idx) * ck::wrapper::size<0, 1>(block_shape) * - ck::wrapper::size<0, 1>(strides) + - ck::wrapper::size<1>(block_idx) * ck::wrapper::size<1>(block_shape) * - ck::wrapper::size<1>(strides); + const auto expected_tile_second_val = expected_tile_first_val + 1; EXPECT_EQ(ck::wrapper::size(packed_tile), expected_tile_size); EXPECT_EQ(packed_tile(0), expected_tile_first_val); + EXPECT_EQ(packed_tile(1), expected_tile_second_val); } } diff --git a/test/wrapper/test_tensor.cpp b/test/wrapper/test_tensor.cpp index 2d4d6f2750..3c7d877528 100644 --- a/test/wrapper/test_tensor.cpp +++ b/test/wrapper/test_tensor.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -100,31 +100,26 @@ TEST(TestTensor, ReadWriteHostMemory) __global__ void TestTensorReadWriteDevice(void* data, void* success) { - constexpr ck::index_t nelems = 8; - constexpr ck::index_t scalar_per_vector = 1; + constexpr ck::index_t nelems = 8; __shared__ ck::index_t p_shared[nelems]; ck::index_t* casted_data_ptr = static_cast(data); bool* casted_success_ptr = static_cast(success); const auto layout = ck::wrapper::make_layout(ck::make_tuple(ck::make_tuple(2, 2), 2)); + constexpr auto vgpr_layout = + ck::wrapper::make_layout(make_tuple(ck::Number{}), make_tuple(ck::Number<1>{})); auto tensor_global = ck::wrapper::make_tensor(casted_data_ptr, layout); - auto tensor_lds = ck::wrapper::make_tensor(p_shared, layout); - auto tensor_vgpr = ck::wrapper::make_register_tensor(); - auto tensor_sgpr = ck::wrapper::make_register_tensor(); + auto tensor_lds = ck::wrapper::make_tensor(p_shared, layout); + auto tensor_vgpr = + ck::wrapper::make_register_tensor( + vgpr_layout); InitTensor(tensor_global); InitTensor(tensor_lds); StaticInitTensor(tensor_vgpr); - StaticInitTensor(tensor_sgpr); *casted_success_ptr = TestTensorCheck1d(tensor_global); *casted_success_ptr &= TestTensorCheck3d(tensor_global); @@ -133,8 +128,6 @@ __global__ void TestTensorReadWriteDevice(void* data, void* success) *casted_success_ptr &= TestTensorCheck3d(tensor_lds); *casted_success_ptr &= StaticTestTensorCheck1d(tensor_vgpr); - - *casted_success_ptr &= StaticTestTensorCheck1d(tensor_sgpr); } TEST(TestTensor, ReadWriteGlobalLdsRegistersMemory) From bb63b9732cf179330ea961e8fbf49ea267827f16 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Fri, 19 Jan 2024 21:02:22 +0800 Subject: [PATCH 56/75] [GEMM] Optimization for MI200/300. (#1135) * Optimize GEMM on MI200/300: 1. Add new blockwise gemm pipeline 2. 
Add irregular splitk intances * clang format + typo fix * Fix a bug --- example/01_gemm/CMakeLists.txt | 3 + example/01_gemm/gemm_xdl_fp16_v2.cpp | 51 + .../run_splitK_gemm_example.inc | 2 +- .../35_splitK_gemm/splitK_gemm_xdl_fp16.cpp | 2 +- include/ck/stream_config.hpp | 4 +- .../block/blockwise_gemm_pipeline_xdlops.hpp | 999 ++++++++++++++ .../impl/device_gemm_xdl_cshuffle_v2.hpp | 306 +++++ .../gpu/grid/block_to_ctile_map.hpp | 301 +++++ .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 1153 +++++++++++++++++ .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 30 + ...uffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 4 +- ...uffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp | 4 +- ..._shuffle_f16_f16_f16_mk_kn_mn_instance.cpp | 4 +- ..._shuffle_f16_f16_f16_mk_nk_mn_instance.cpp | 4 +- ...l_splitk_f16_f16_f16_mk_kn_mn_instance.cpp | 81 +- ...l_splitk_f16_f16_f16_mk_nk_mn_instance.cpp | 82 ++ .../profiler/profile_gemm_splitk_impl.hpp | 2 +- 17 files changed, 3015 insertions(+), 17 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_fp16_v2.cpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 56897571c7..5b71cd1548 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -19,6 +19,9 @@ add_custom_target(example_gemm_xdl) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) +add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp new file mode 100644 index 0000000000..eba0ea9d11 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
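+
+// Note (illustrative summary): this example exercises the new
+// DeviceGemm_Xdl_CShuffleV2 device op on a row-major fp16 GEMM with fp32
+// accumulation; the template arguments below select the block tile, XDL
+// shape and per-thread transfer widths.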
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using F16 = ck::half_t; +using F32 = float; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2< + ALayout, BLayout, CLayout, + F16, F16, F16, F32, F16, + PassThrough, PassThrough, PassThrough, GemmDefault, + 2, 256, + 256, 256, + 32, 8, 4, + 32, 32, + 4, 4, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::LoopScheduler::Default, ck::PipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc index e9bd5c552d..e3690984ab 100644 --- a/example/35_splitK_gemm/run_splitK_gemm_example.inc +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con if(config.time_kernel) { - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index 74fb16e15b..dc54bc30ef 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -42,7 +42,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::KPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index e6a6808244..a5b1407305 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -11,6 +11,6 @@ struct StreamConfig hipStream_t stream_id_ = nullptr; bool time_kernel_ = false; int log_level_ = 0; - int cold_niters_ = 1; - int nrepeat_ = 10; + int cold_niters_ = 5; + int nrepeat_ = 50; }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp new file mode 100644 index 0000000000..7b2aaa76bb --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -0,0 +1,999 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/loop_scheduler.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
+#include "ck/tensor_description/tensor_adaptor.hpp"
+
+// Double LDS buffer
+// Prefetch 2 stage
+// Local prefetch 1 stage
+
+namespace ck {
+
+template
+struct BlockwiseGemmXdlops_pipeline_hotloop_inst
+{
+    static constexpr index_t WaveSize = 64;
+    static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
+
+    static constexpr index_t A_Buffer_Load_Inst_Num =
+        MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
+    static constexpr index_t B_Buffer_Load_Inst_Num =
+        NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
+
+    static constexpr index_t A_LDS_Write_Inst_Num =
+        MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
+    static constexpr index_t B_LDS_Write_Inst_Num =
+        NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
+
+    static constexpr index_t A_LDS_Read_Inst_Num =
+        WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
+    static constexpr index_t B_LDS_Read_Inst_Num =
+        WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
+
+    static constexpr index_t C_MFMA_Inst_Num =
+        MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
+
+    static constexpr auto Print()
+    {
+        printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
+               BlockSize,
+               WaveSize,
+               MPerBlock,
+               NPerBlock,
+               KPerBlock,
+               MPerXDL,
+               NPerXDL,
+               KPerXDL);
+
+        printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
+               "%d, %d\n C MFMA inst: %d\n",
+               A_Buffer_Load_Inst_Num,
+               B_Buffer_Load_Inst_Num,
+               A_LDS_Write_Inst_Num,
+               B_LDS_Write_Inst_Num,
+               A_LDS_Read_Inst_Num,
+               B_LDS_Read_Inst_Num,
+               C_MFMA_Inst_Num);
+    }
+};
+
+template <
+    index_t BlockSize,
+    typename FloatAB,
+    typename FloatAcc,
+    typename ATileDesc,
+    typename BTileDesc,
+    typename AMmaTileDesc,
+    typename BMmaTileDesc,
+    index_t MPerBlock,
+    index_t NPerBlock,
+    index_t KPerBlock,
+    index_t MPerXDL,
+    index_t NPerXDL,
+    index_t MRepeat,
+    index_t NRepeat,
+    index_t KPack,
+    bool TransposeC = false,
+    index_t AMmaKStride =
+        KPack* XdlopsGemm{}.K0PerXdlops,
+    index_t BMmaKStride =
+        KPack* XdlopsGemm{}.K0PerXdlops>
+struct BlockwiseGemmXdlops_pipeline_v4
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+
+    using ThisThreadBlock = ThisThreadBlock;
+
+    static constexpr index_t WaveSize = get_warp_size();
+
+    static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
+    static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
+    static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
+    static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
+
+    static constexpr auto xdlops_gemm =
+        XdlopsGemm{};
+
+    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
+    static constexpr index_t KRepeat = KPerThread / KPack;
+
+    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
+    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
+
+    using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst;
+
+    static_assert(KPerThread % KPack == 0,
+                  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
+
+    StaticBufferTupleOfVector
+        
c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ + BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + + // HotLoopInstList::Print(); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / 
(MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __device__ static constexpr auto HotLoopScheduler() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_buffer_load_inst = + HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA + }); + } + + template + __device__ static constexpr auto TailScheduler() + { + } + + template <> + __device__ static constexpr auto TailScheduler<1>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_write_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA + }); + } + + template <> + __device__ static constexpr auto TailScheduler<2>() + { + // schedule + 
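+        // __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the backend
+        // scheduler to place the next <size> instructions matching <mask> at this point;
+        // the masks used in these schedulers follow the LLVM AMDGPU convention:
+        //   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write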
constexpr auto num_ds_read_inst =
+            HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
+        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
+
+        constexpr auto num_issue = num_ds_read_inst;
+
+        static_for<0, num_issue, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(
+                0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA
+        });
+    }
+
+    static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
+    static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
+
+    template
+    __device__ void Run(const AGridDesc& a_grid_desc,
+                        const ABlockDesc& a_block_desc,
+                        ABlockTransfer& a_blockwise_copy,
+                        const AGridBuffer& a_grid_buf,
+                        ABlockBuffer& a_block_buf,
+                        const ABlockTransferStep& a_block_copy_step,
+                        const BGridDesc& b_grid_desc,
+                        const BBlockDesc& b_block_desc,
+                        BBlockTransfer& b_blockwise_copy,
+                        const BGridBuffer& b_grid_buf,
+                        BBlockBuffer& b_block_buf,
+                        const BBlockTransferStep& b_block_copy_step,
+                        CThreadBuffer& c_thread_buf,
+                        index_t num_loop) const
+    {
+        __builtin_amdgcn_sched_barrier(0);
+        auto a_thread_buf = make_static_buffer(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer(
+            b_thread_desc_.GetElementSpaceSize());
+
+        StaticallyIndexedArray{}> a_thread_bufs;
+        StaticallyIndexedArray{}> b_thread_bufs;
+        // Inst List:
+        // ds_read_b128: 16
+        // ds_write_b128: 8
+        // buffer_load_dwordx4: 16
+        // v_mfma: 0
+        // -------------------------------------------------------------------------------------------
+
+        // Global prefetch 1st, Fill Ping LDS
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
+
+        // Local prefetch 1st, Fill Ping Reg
+        block_sync_lds();
+        static_for<0, KRepeat, 1>{}([&](auto k) {
+            static_for<0, MRepeat, 1>{}([&](auto m0) {
+                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                   make_tuple(m0, I0, I0, Number{}),
+                                   a_block_buf.At(I0),
+                                   a_thread_desc_,
+                                   make_tuple(m0, I0, k, I0),
+                                   a_thread_bufs(I0));
+                static_for<0, NRepeat, 1>{}([&](auto n0) {
+                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
+                                       make_tuple(n0, I0, I0, Number{}),
+                                       b_block_buf.At(I0),
+                                       b_thread_desc_,
+                                       make_tuple(n0, I0, k, I0),
+                                       b_thread_bufs(I0));
+                });
+            });
+        });
+
+        // Global prefetch 2nd, Fill Pong LDS
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));
+
+        // Global prefetch 3rd
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+            // This hot loop unrolls two iterations per trip, to implement the double local
+            // buffer strategy
+            do
+            {
+                // 
------------------------------------------------------------------------------------------- + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename 
vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 2; + } while(i < (num_loop - 3)); + } + + // tail + if constexpr(TailNum == 3) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + 
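+            // TailNum == 3 drains an odd number of remaining K iterations (see
+            // CalculateKBlockLoopTailNum in the gridwise gemm header later in this patch):
+            // the final MFMA pass below consumes the last register buffer with no
+            // remaining global/LDS traffic to overlap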
__builtin_amdgcn_sched_barrier(0); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PongP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PongP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == 2) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = 
make_naive_tensor_descriptor(
+        make_tuple(Number{}, I1, Number{}, Number{}),
+        make_tuple(
+            Number{}, Number{}, Number{}, I1));
+
+    // B[N0, N1, N2, KPack]
+    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
+        make_tuple(Number{}, I1, Number{}, Number{}),
+        make_tuple(
+            Number{}, Number{}, Number{}, I1));
+
+    // C[M, N, NumRegXdlops]
+    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops()));
+
+    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4,
+                                                        Sequence<0, 1, 2, 3>,
+                                                        3,
+                                                        A_K1,
+                                                        A_K1>;
+
+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4,
+                                                        Sequence<0, 1, 2, 3>,
+                                                        3,
+                                                        B_K1,
+                                                        B_K1>;
+
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
+};
+
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000..d49c63f147 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,306 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp"
+#include "ck/host_utility/device_prop.hpp"
+#include "ck/host_utility/kernel_launch.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
+// non-c-shuffle version currently has compiler issues with register spill which further cause
+// validation failures.
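For orientation, the device op declared next follows CK's usual client pattern: construct the op, build an Argument, gate on IsSupportedArgument, then launch through an Invoker. A minimal usage sketch, assuming a DeviceGemmInstance alias like the one in gemm_xdl_fp16_v2.cpp above and device buffers that are already allocated and initialized; run_gemm_once is an illustrative wrapper, not part of this patch:

float run_gemm_once(const ck::half_t* p_a, const ck::half_t* p_b, ck::half_t* p_c,
                    ck::index_t M, ck::index_t N, ck::index_t K,
                    ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    auto gemm     = DeviceGemmInstance{};
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K,
                                      StrideA, StrideB, StrideC,
                                      PassThrough{}, PassThrough{}, PassThrough{});

    if(!gemm.IsSupportedArgument(argument))
    {
        return -1.f; // this instance cannot handle the given shape/layout
    }

    // time_kernel = true: Run returns the averaged kernel time in ms
    return invoker.Run(argument, StreamConfig{nullptr, true});
}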
+template +struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffleV2" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff 
--git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 7bb47e9d3c..6266fb40f0 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif } template @@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt; + +template +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt( + const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; + __host__ __device__ + BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& + operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& + operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, + index_t N, + index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif + } + + template + __host__ __device__ + BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8) + : BlockToCTileMap_Grouped_M00_N0_M01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) + { + } + + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum); + auto group_id = block_1d_id % GroupNum; + auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum; + + index_t idx_N0 = remap_block_1d_id % N0; + index_t idx_M0 = remap_block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t M01_; +}; + +// keep the redundant type argument for backward compatibility +template +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt + : BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + using BlockToCTileMap_Grouped_M00_N0_M01Adapt:: + BlockToCTileMap_Grouped_M00_N0_M01Adapt; +}; + +// columns of row-vectors +// This C-tile map dynamically adjusts N01 when C-tile index is out of range +template +struct BlockToCTileMap_N00_M0_N01Adapt; + +template +struct BlockToCTileMap_N00_M0_N01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8) + : M_(M), N_(N), N01_(N01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_); + } +#endif + } + + template + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t N01 = 8) + : BlockToCTileMap_N00_M0_N01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01) + { + } + + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return 
CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_M0 = block_1d_id % M0; + index_t idx_N0 = block_1d_id / M0; + + const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_; + + index_t idx_N00 = idx_N0 / N01_; + index_t idx_N01 = idx_N0 % N01_; + index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0; + + /** + * idxN0 + * + * |< mtx N >| + * + * |<---N01--->| + * - |-----------|-----------|-----------|-----|-----|- + * ^ | 0 ----------> 1 | | | | + * | | / | | | | M_0 MPerBlock + * | / | | | | + * |------/----------------|-----------|-----|-----|- + * | | | | | | | + * idxM0 | V | | | | | M_1 MPerBlock + * | 2 ----------> 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | blockid | | | | + * | | 5 | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * Example: + * assume: + * N0 = 5 + * M0 = 4 + * block_1d_id = 5 + * N01 = 2 + * + * idx_M0 = 1 + * idx_N0 = 1 + * N01_adapt = 2 + * idx_N00 = 0 + * idx_N01 = 1 + * idx_M0_N01_local = 5 + * output {2, 1} + */ + + return make_tuple(idx_M0_N01_local / N01_adapt, + idx_M0_N01_local % N01_adapt + idx_N00 * N01_); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t N01_; +}; + // 2D slices of column-vectors in 3D space // This C-tile map dynamically adjusts M01 when C-tile index is out of range template diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000..2ad2dd9915 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,1153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
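The M01 swizzle stage shared by these block-to-C-tile maps is easiest to sanity-check on the host. Below is a self-contained sketch of the same integer math as CalculateBottomIndex above (the grouped variant additionally remaps block_1d_id across GroupNum groups first); m00_n0_m01_adapt is an illustrative name, and the assert reproduces the worked example from the comment (M0 = 5, N0 = 4, block 5, M01 = 2 -> tile {1, 2}):

#include <cassert>
#include <utility>

std::pair<int, int> m00_n0_m01_adapt(int block_1d_id, int M0, int N0, int M01)
{
    const int idx_N0 = block_1d_id % N0;
    const int idx_M0 = block_1d_id / N0;

    // the last (M0 % M01) row-blocks form a shorter column group
    const int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

    const int idx_M00          = idx_M0 / M01;
    const int idx_M01          = idx_M0 % M01;
    const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

    return {idx_N0_M01_local % M01_adapt + idx_M00 * M01,
            idx_N0_M01_local / M01_adapt};
}

int main()
{
    const auto tile = m00_n0_m01_adapt(5, 5, 4, 2);
    assert(tile.first == 1 && tile.second == 2); // matches the documented example
    return 0;
}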
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+namespace ck {
+
+template
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
+#endif
+        // __attribute__((amdgpu_waves_per_eu(1, 1)))
+        kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
+    // can operate on different LDS chunks at the same time without an ordering dependency
+    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run(
+        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
+#else
+    ignore = karg;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
+template
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
+#endif
+    kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
+                                const FloatB* p_b_grid,
+                                FloatC* p_c_grid,
+                                typename GridwiseGemm::Problem problem)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    GridwiseGemm::template Run(
+        p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = problem;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
+template
+struct GridwiseGemm_xdl_cshuffle_v2
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};
+    static constexpr auto I7 = Number<7>{};
+
+    // K1 should be Number<...>
+    static constexpr auto AK0Number = Number{};
+    static constexpr auto BK0Number = Number{};
+    static constexpr auto AK1Number = Number{};
+    static constexpr auto BK1Number = Number{};
+
+    using ThisThreadBlock = ThisThreadBlock;
+
+    __host__ static auto CalculateGridSize(index_t M, index_t N)
+    {
+        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
+    }
+
+    __host__ static auto CalculateMPadded(index_t M)
+    {
+        return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
+    }
+
+    __host__ static auto CalculateNPadded(index_t N)
+    {
+        return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
+    }
+
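+    // e.g. with MPerBlock = 256, M = 1000 pads to MPadded = ceil(1000 / 256) * 256 = 1024,
+    // so the tile grid also covers the ragged final row-block of C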
+ __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) 
+ { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), 
+ make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct 
Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + }; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == 
tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(num_k_loop < 4) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return num_loop > 3; + } + + __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + if(num_loop % 2 == 1) + return 3; + else + return 2; + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto 
c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } +#if 0 + if(threadIdx.x == 0){ + printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n", + get_block_1d_id(), + block_work_idx[I0], + block_work_idx[I1], + __smid()>>6 & 0xf, + __smid()>>4 & 0x3, + __smid() & 0xf); + } +#endif + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + ComputeTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + ComputeTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // 
b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + // BlockSize, + // ComputeType, + // FloatGemmAcc, + // decltype(a_block_desc_ak0_m_ak1), + // decltype(b_block_desc_bk0_n_bk1), + // MPerXdl, + // NPerXdl, + // MXdlPerWave, + // NXdlPerWave, + // KPack, + // LoopSched>(); + auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4< + BlockSize, + ComputeTypeA, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>{}; // TransposeC + + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
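+            // [Editor's note: illustrative worked example; not part of the
+            //  original patch. The tile shape below is taken from the
+            //  DeviceGemm_Xdl_CShuffleV2 instances added later in this series
+            //  (BlockSize = 256, MPerBlock = NPerBlock = 256,
+            //  MPerXdl = NPerXdl = 32, MXdlPerWave = NXdlPerWave = 4):
+            //
+            //    MWave = MPerBlock / (MXdlPerWave * MPerXdl) = 256 / (4 * 32) = 2
+            //    NWave = NPerBlock / (NXdlPerWave * NPerXdl) = 256 / (4 * 32) = 2
+            //
+            //  A 2 x 2 grid of 64-lane wavefronts covers the 256 threads of the
+            //  workgroup, and each wavefront owns MXdlPerWave x NXdlPerWave = 16
+            //  MFMA output tiles of 32 x 32 elements, i.e. the full
+            //  256 x 256 C tile of the block.]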
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 7bab488e58..87e1e0e8d9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) + { + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, 
Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } else { return transform_tensor_descriptor( @@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) + { + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } else { return transform_tensor_descriptor( diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp index 06a117919a..d547b3e602 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Row, 
Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp index 48351b2f29..60d4ccf525 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index ad846e4c80..4a2526b3a4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -57,7 +58,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index 3c50cf2273..01e0ebdb34 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -52,7 +53,8 @@ using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1> + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV2< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 2, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopScheduler::Default, PipelineVersion::v1> #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves , diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 9fd83cdec8..45096f659f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -27,6 +27,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmKPadding = 
ck::tensor_operation::device::GemmSpecialization::KPadding; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; @@ -110,17 +111,39 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< // clang-format on >; -template +template using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< // clang-format off //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave> + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, 
PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, 
PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche> // clang-format on >; @@ -141,9 +164,51 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( add_device_operation_instances( instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); - add_device_operation_instances( - instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + 
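+// [Editor's note: illustrative sketch; not part of the original patch. The
+//  calls above and below register the irregular-instance tuple once per
+//  (GemmSpecialization, PipelineVersion, LoopScheduler) combination. A
+//  hypothetical helper -- add_irregular and InstanceVec are invented here
+//  purely for illustration -- could express the same pattern as:
+//
+//      template <tensor_operation::device::GemmSpecialization Spec,
+//                ck::PipelineVersion Ver,
+//                ck::LoopScheduler Sched,
+//                typename InstanceVec>
+//      void add_irregular(InstanceVec& instances)
+//      {
+//          add_device_operation_instances(
+//              instances,
+//              device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
+//                  Spec, Ver, Sched>{});
+//      }
+//
+//      // usage: add_irregular<GemmKPadding, ck::PipelineVersion::v2,
+//      //                      ck::LoopScheduler::Default>(instances);
+//
+//  The patch instead writes each combination out explicitly, which is more
+//  verbose but keeps every instantiated configuration visible at a glance.]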
add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 25c94bb886..b22f4a3beb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -27,6 +27,7 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; @@ -95,6 +96,41 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2> // clang-format on >; +template +using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| 
PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, 
PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche> + // clang-format on + >; void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( std::vector{}); + + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v2, 
+ ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); } } // namespace instance diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index 6816d2c538..5d5ae1ad15 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -145,7 +145,7 @@ bool profile_gemm_splitk_impl(int do_verification, // profile device GEMM instances for(auto& op_ptr : op_ptrs) { - std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 32, 36, 40, 64, 96, 128}; + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38}; if(KBatch > 0) { From 1be4706366f77b4da7e5114e4e334e9aa7dd3b62 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Mon, 22 Jan 2024 10:42:26 -0600 Subject: [PATCH 57/75] fixed return (#1138) --- profiler/include/profiler/profile_gemm_impl.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index 586a356ecc..0419ccd8e7 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -298,7 +298,7 @@ int profile_gemm_impl(int do_verification, } } - return pass ? 0 : 1; + return pass; } } // namespace profiler From 6169fbbdb3a55a6458c8498276f04e500a0893a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 24 Jan 2024 17:19:02 +0100 Subject: [PATCH 58/75] Fix possible linting errors in changelog (#1141) * Fix possible linting errors in changelog * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md --- CHANGELOG.md | 58 ++++++++++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12cc4363de..c721039523 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ None None ### Additions -- Introduce wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) +* Introduced wrapper sublibrary (limited functionality). 
(#1071, #1098, #1108, #1126) ### Changes None @@ -19,49 +19,49 @@ None ## CK for ROCm 6.0.0 ### Fixes - - Fixed a hazard associated with inline v_dot (#808) - - Fixed two bugs in grouped convolution backward data without K padding (#848 #876) + * Fixed a hazard associated with inline v_dot (#808) + * Fixed two bugs in grouped convolution backward data without K padding (#848 #876) ### Optimizations None ### Additions -- Added an image to a column kernel (#867) -- Added a column to an image kernel (#930) -- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) -- Grouped convolution support for small K and C (#822 #879 #897) -- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) -- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) -- Support for Batched Gemm DL (#732) +* Added an image to a column kernel (#867) +* Added a column to an image kernel (#930) +* Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) +* Grouped convolution support for small K and C (#822 #879 #897) +* Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) +* Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) +* Support for Batched Gemm DL (#732) ### Changes - - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) + * Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) ## CK 0.2.0 for ROCm 5.7.0 ### Fixes -- Fixed a bug in 6-dimensional kernels (#555) -- Fixed a test case failure with grouped convolution backward weight (#524) +* Fixed a bug in 6-dimensional kernels (#555) +* Fixed a test case failure with grouped convolution backward weight (#524) ### Optimizations -- Improved the performance of the normalization kernel +* Improved the performance of the normalization kernel ### Additions -- New CMake flags: - - "DL_KERNELS"-- Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances - - "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types - - "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler -- New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler -- Support for MI300A/MI300X -- Support for AMD RDNA 3 -- New user tutorial (#563) -- Additional instances for irregular GEMM sizes (#560) -- New inter-wave consumer-producer programming model for GEMM kernels (#310) -- GEMM with support multiple elementwise fusions (multi-D) (#534) -- Multi-embeddings support (#542) -- AMD RDNA 3 blockwise GEMM and real GEMM support (#541) -- AMD RDNA grouped convolution backward weight support (#505) -- MaxPool and AvgPool forward (#815); MaxPool backward (#750) +* New CMake flags: + * "DL_KERNELS"-* Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances + * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types + * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler +* New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler +* Support for MI300A/MI300X +* Support for AMD RDNA 3 +* New user tutorial (#563) +* Additional instances for irregular GEMM sizes (#560) +* New 
inter-wave consumer-producer programming model for GEMM kernels (#310)
+* GEMM with support for multiple elementwise fusions (multi-D) (#534)
+* Multi-embeddings support (#542)
+* AMD RDNA 3 blockwise GEMM and real GEMM support (#541)
+* AMD RDNA grouped convolution backward weight support (#505)
+* MaxPool and AvgPool forward (#815); MaxPool backward (#750)

### Changes
None

From 180e5720760d201b4bfc15f99f59a311b1bc5562 Mon Sep 17 00:00:00 2001
From: Illia Silin <98187287+illsilin@users.noreply.github.com>
Date: Wed, 24 Jan 2024 13:47:48 -0800
Subject: [PATCH 59/75] Fixing most of the cppcheck errors. (#1142)

* fix cppcheck errors, first pass
* fix format
* fix returned value in examples
* add macro definitions for cppcheck
* fix the profile_gemm logic
* update the gemm profiler logic
* add more definitions to cppcheck, fix a couple more errors
* replace runtime error with message in device function
* fix a couple of int4 issues
* no return for fill function
* fix errors in data_types.hpp
* fix format
* fix a few remaining errors
* fix errors in data_types.hpp
* fix the last couple of errors in data_types.hpp
---
 Jenkinsfile                                   |  8 ++-
 example/01_gemm/gemm_dl_int4.cpp              |  5 +-
 example/01_gemm/gemm_xdl_int4.cpp             |  5 +-
 .../gemm_add_add_fastgelu_xdl_int4.cpp        |  5 +-
 .../convnd_fwd_max_xdl_int4.cpp               |  5 +-
 .../batched_gemm_reduce_xdl_fp16.cpp          |  9 ++-
 ...rouped_conv_fwd_bias_relu_add_xdl_int4.cpp |  5 +-
 .../batched_gemm_gemm_xdl_int4.cpp            |  5 +-
 .../grouped_conv_conv_fwd_xdl_int4.cpp        |  5 +-
 example/48_pool3d_fwd/pool3d_fwd_common.hpp   |  4 ++
 .../51_avgpool3d_bwd/avgpool3d_bwd_common.hpp |  4 ++
 include/ck/utility/data_type.hpp              | 65 +++++++++++++++++++
 .../cpu/reference_column_to_image.hpp         |  2 +
 .../cpu/reference_conv_bwd_data.hpp           |  3 +
 .../cpu/reference_conv_bwd_weight.hpp         |  2 +
 .../cpu/reference_conv_fwd.hpp                |  2 +
 .../cpu/reference_gemm.hpp                    |  7 +-
 .../cpu/reference_image_to_column.hpp         |  2 +
 .../gpu/batched_gemm_gemm.hpp                 |  3 +-
 .../gpu/gemm_streamk.hpp                      |  3 +-
 profiler/src/profile_gemm.cpp                 | 12 +++-
 script/clang-format-overwrite.sh              |  2 +-
 .../test_conv_tensor_rearrange_interface.cpp  |  2 +
 23 files changed, 125 insertions(+), 40 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index e333a35ecd..9359aa1f69 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -304,7 +304,7 @@ def buildHipClangJob(Map conf=[:]){
     gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
         withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
-            timeout(time: 20, unit: 'HOURS')
+            timeout(time: 48, unit: 'HOURS')
            {
                cmake_build(conf)
            }
@@ -755,7 +755,11 @@ pipeline {
                        -o -iname \'*.cl\' \
                        | grep -v 'build/' \
                        | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \
-                       /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include --file-filter=*.cpp --enable=all --output-file=ck_cppcheck.log"
+                       /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \
+                       -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 -D DL_KERNELS \
+                       -D __gfx908__ -D __gfx90a__ -D __gfx940__ -D __gfx941__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
+                       -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \
+                       --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
                }
        steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index e55ae14013..43c0cfe2e0 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -43,3 +41,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index f6238c7aa5..fb4f383fae 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp index f206bbeb41..1d0b0f7861 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
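// The int4 example fixes in this patch all follow one schematic rewrite
// (summarized here for orientation; the concrete hunks follow below):
//
//   before:  #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
//            #error Should compile this file with ck::int4_t support
//            #endif
//            ... example body ...
//
//   after:   #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
//            ... example body, including main() ...
//            #endif
//
// With the macro undefined, the file now compiles to an empty translation unit
// instead of failing outright, which is what lets cppcheck scan these sources
// with -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 (see the Jenkinsfile change
// earlier in this patch).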
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; #include "run_convnd_fwd_max_example.inc" int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } +#endif diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index e363dc5c12..62295c57eb 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -272,15 +272,14 @@ int main(int argc, char* argv[]) { for(int m = 0; m < M; ++m) { - auto reduce0_acc = reduce0_op.GetIdentityValue(); - auto reduce1_acc = reduce1_op.GetIdentityValue(); - + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + ReduceAccDataType d0_val = 0; + ReduceAccDataType d1_val = 0; for(int n = 0; n < N; ++n) { auto c_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); - ReduceAccDataType d0_val; - ReduceAccDataType d1_val; UnaryIdenticElementOp{}(d0_val, c_val); UnarySquareElementOp{}(d1_val, c_val); diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp index 5494563fdd..6f91d51a5f 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +#endif diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp index d166214c33..2caee6b8dc 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o Gemm1 */ -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } +#endif diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp index 80f6e9ae05..cf7b1ce3a8 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
-#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } +#endif diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 39032fa123..788f38ec52 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -32,6 +32,8 @@ std::vector f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Pool3d_fwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Pool3d_fwd: problem with layout. "); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template } }; +int static err = 0; template struct vector_type { @@ -221,6 +222,10 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } template @@ -236,6 +241,10 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } }; @@ -278,6 +287,10 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } template @@ -298,6 +311,10 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } }; @@ -347,6 +364,10 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } template @@ -372,6 +393,10 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } }; @@ -428,6 +453,10 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } template @@ -458,6 +487,10 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } }; @@ -520,6 +553,10 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } template @@ -554,6 +591,10 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } }; @@ -623,6 +664,10 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } template @@ -662,6 +707,10 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } }; @@ -737,6 +786,10 @@ struct vector_type { return data_.d128x1_; } + else + { + return err; + } } template @@ -780,6 +833,10 @@ struct vector_type { return data_.d128x1_; } + else + { + return err; + } } }; @@ -861,6 +918,10 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } template @@ -908,6 +969,10 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } }; diff --git 
a/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp index 45e35ec56d..5f2ab12164 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp @@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator return 0; } + throw std::runtime_error("Col2Img: number of dimensions should be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 50040a2441..bfb8b48187 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator return 0; } + throw std::runtime_error( + "Conv_bwd_data: number of dimensions must be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp index 02ad7a033a..d0b98efd1f 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp @@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator return 0; } + throw std::runtime_error("Conv_bwd: number of dimensions must be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index ffc9470df2..d63b5256f9 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -360,6 +360,8 @@ struct ReferenceConvFwd : public device::BaseOperator return 0; } + throw std::runtime_error("Conv_fwd: number of dimensions must be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 6e39dee71c..4d52563f42 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -63,12 +63,11 @@ struct ReferenceGemm : public device::BaseOperator const int K = arg.a_m_k_.mDesc.GetLengths()[1]; AccDataType v_acc = 0; + ComputeTypeA v_a = 0; + ComputeTypeB v_b = 0; for(int k = 0; k < K; ++k) { - ComputeTypeA v_a; - ComputeTypeB v_b; - // use PassThrough instead of ConvertBF16RTN for reference calculation if constexpr(is_same_v) @@ -94,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } - CDataType v_c; + CDataType v_c = 0; arg.c_element_op_(v_c, v_acc); diff --git 
a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp index 750d4d14f8..4682c5c223 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp @@ -230,6 +230,8 @@ struct ReferenceImageToColumn : public device::BaseOperator return 0; } + throw std::runtime_error("Img2Col: number of dimensions should be between 1 and 3."); + return 1; } float Run(const device::BaseArgument* p_arg, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp index 77ad36b97b..42ca8e755d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_gemm.hpp @@ -106,9 +106,8 @@ struct DeviceOperationInstanceFactory< return op_ptrs; } }; - +#endif } // namespace instance } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp index 2df378b0c6..730785f702 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp @@ -114,9 +114,8 @@ struct DeviceOperationInstanceFactory Date: Thu, 25 Jan 2024 19:53:15 +0800 Subject: [PATCH 60/75] layernorm & groupnorm bwd gamma beta (#1133) * Add layernorm bwd gamma beta external api * Add groupnorm external api * Add layernorm bwd gamma beta profiler * Add groupnorm bwd gamma beta ckProfiler * Add layernorm & groupnorm bwd gamma beta test * Fix groupnorm bwd gamma beta profiler bug * Layernorm bwd weight client example * Groupnorm bwd weight client example * clang format * Remove useless header * Let inv_std be positive * Rename to num_bytes and move this calculation outside the loop --- client_example/05_layernorm/CMakeLists.txt | 3 + .../layernorm2d_bwd_gamma_beta.cpp | 171 ++++++++++++ client_example/18_groupnorm/CMakeLists.txt | 3 + .../18_groupnorm/groupnorm_bwd_gamma_beta.cpp | 180 ++++++++++++ .../gpu/groupnorm_bwd_gamma_beta.hpp | 64 +++++ .../gpu/layernorm_bwd_gamma_beta.hpp | 83 ++++++ ...ayernorm2d_bwd_gamma_beta_f16_instance.cpp | 2 +- ...ayernorm2d_bwd_gamma_beta_f32_instance.cpp | 2 +- .../profile_groupnorm_bwd_gamma_beta_impl.hpp | 261 +++++++++++++++++ .../profile_layernorm_bwd_gamma_beta_impl.hpp | 263 ++++++++++++++++++ profiler/src/CMakeLists.txt | 3 + .../src/profile_groupnorm_bwd_gamma_beta.cpp | 104 +++++++ .../src/profile_layernorm_bwd_gamma_beta.cpp | 112 ++++++++ test/CMakeLists.txt | 1 + .../CMakeLists.txt | 13 + .../test_groupnorm_bwd_gamma_beta_fp32.cpp | 51 ++++ .../test_layernorm2d_bwd_gamma_beta_fp32.cpp | 48 ++++ 17 files changed, 1362 insertions(+), 2 deletions(-) create mode 100644 client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp create mode 100644 client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp create mode 100644 
profiler/include/profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp create mode 100644 profiler/include/profiler/profile_layernorm_bwd_gamma_beta_impl.hpp create mode 100644 profiler/src/profile_groupnorm_bwd_gamma_beta.cpp create mode 100644 profiler/src/profile_layernorm_bwd_gamma_beta.cpp create mode 100644 test/normalization_bwd_gamma_beta/CMakeLists.txt create mode 100644 test/normalization_bwd_gamma_beta/test_groupnorm_bwd_gamma_beta_fp32.cpp create mode 100644 test/normalization_bwd_gamma_beta/test_layernorm2d_bwd_gamma_beta_fp32.cpp diff --git a/client_example/05_layernorm/CMakeLists.txt b/client_example/05_layernorm/CMakeLists.txt index 246f877cde..b7b3c830ed 100644 --- a/client_example/05_layernorm/CMakeLists.txt +++ b/client_example/05_layernorm/CMakeLists.txt @@ -1,6 +1,9 @@ add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp) target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations) +add_executable(client_layernorm2d_bwd_gamma_beta layernorm2d_bwd_gamma_beta.cpp) +target_link_libraries(client_layernorm2d_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations) + add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp) target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..98b394add6 --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N); + SimpleDeviceMem x_dev(sizeof(XDataType) * M * N); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * N); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * N); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N 
+                            + sizeof(MeanInvStdDataType) * M * 2 + sizeof(DGammaDataType) * N
+                            + sizeof(DBetaDataType) * N;
+
+    for(int i = 0; i < op_ptrs.size(); ++i)
+    {
+        auto& op_ptr = op_ptrs[i];
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths
+                                                        {N, 1}, // dyStrides
+                                                        {N, 1}, // xStrides
+                                                        {1, 0}, // meanStrides
+                                                        {1, 0}, // invStdStrides
+                                                        {N},    // outLengths
+                                                        {1},    // dgammaStrides
+                                                        {1},    // dbetaStrides
+                                                        {0},    // reduceDims
+                                                        dy_dev.GetDeviceBuffer(),
+                                                        x_dev.GetDeviceBuffer(),
+                                                        mean_dev.GetDeviceBuffer(),
+                                                        inv_std_dev.GetDeviceBuffer(),
+                                                        dgamma_dev.GetDeviceBuffer(),
+                                                        dbeta_dev.GetDeviceBuffer());
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+            SimpleDeviceMem workspace(workspace_sz);
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
+            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
+            float gb_per_sec = num_bytes / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
+                      << op_name << std::endl;
+
+            if(ave_time < best_ave_time)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
+              << best_op_name << std::endl;
+
+    // run the best instance
+    if(found)
+    {
+        auto& op_ptr = op_ptrs[best_op_id];
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths
+                                                        {N, 1}, // dyStrides
+                                                        {N, 1}, // xStrides
+                                                        {1, 0}, // meanStrides
+                                                        {1, 0}, // invStdStrides
+                                                        {N},    // outLengths
+                                                        {1},    // dgammaStrides
+                                                        {1},    // dbetaStrides
+                                                        {0},    // reduceDims
+                                                        dy_dev.GetDeviceBuffer(),
+                                                        x_dev.GetDeviceBuffer(),
+                                                        mean_dev.GetDeviceBuffer(),
+                                                        inv_std_dev.GetDeviceBuffer(),
+                                                        dgamma_dev.GetDeviceBuffer(),
+                                                        dbeta_dev.GetDeviceBuffer());
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+            SimpleDeviceMem workspace(workspace_sz);
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt
index deb50f6fce..e04c26d8e7 100644
--- a/client_example/18_groupnorm/CMakeLists.txt
+++ b/client_example/18_groupnorm/CMakeLists.txt
@@ -1,5 +1,8 @@
 add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp)
 target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations)
 
+add_executable(client_groupnorm_bwd_gamma_beta groupnorm_bwd_gamma_beta.cpp)
+target_link_libraries(client_groupnorm_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations)
+
 add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp)
 target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations)
diff --git a/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp
b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..c2fbe285df --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t length = N * H * W * G * C; + + std::vector strideDy = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector strideX = strideDy; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + std::vector strideDGammaBeta = {C, 1}; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * length); + SimpleDeviceMem x_dev(sizeof(XDataType) * length); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * G * C); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * G * C); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * length + sizeof(XDataType) * length + + sizeof(GammaDataType) * G * C + sizeof(MeanInvStdDataType) * N * G * 2 + + sizeof(DGammaDataType) * G * C + sizeof(DBetaDataType) * G * C; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + {G, C}, + strideDGammaBeta, + strideDGammaBeta, + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + float gb_per_sec 
= num_bytes / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
+                      << op_name << std::endl;
+
+            if(ave_time < best_ave_time)
+            {
+                found           = true;
+                best_op_id      = i;
+                best_op_name    = op_name;
+                best_ave_time   = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    // run the best instance
+    if(found)
+    {
+        std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
+                  << best_op_name << std::endl;
+
+        auto& op_ptr = op_ptrs[best_op_id];
+        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
+                  << std::endl;
+
+        auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C},
+                                                        strideDy,
+                                                        strideX,
+                                                        strideMeanInvStd,
+                                                        strideMeanInvStd,
+                                                        {G, C},
+                                                        strideDGammaBeta,
+                                                        strideDGammaBeta,
+                                                        {0, 1, 2}, // reduceDims
+                                                        dy_dev.GetDeviceBuffer(),
+                                                        x_dev.GetDeviceBuffer(),
+                                                        mean_dev.GetDeviceBuffer(),
+                                                        inv_std_dev.GetDeviceBuffer(),
+                                                        dgamma_dev.GetDeviceBuffer(),
+                                                        dbeta_dev.GetDeviceBuffer());
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
+            SimpleDeviceMem workspace(workspace_sz);
+            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
+
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
+        }
+
+        std::cout << "Done" << std::endl;
+    }
+
+    return 0;
+}
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp b/library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp
new file mode 100644
index 0000000000..3f888d5c67
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp
@@ -0,0 +1,64 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp"
+
+#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+#ifdef CK_ENABLE_FP32
+// FP32
+void add_device_groupnorm_bwd_gamma_beta_f32_instances(
+    std::vector>>&);
+#endif
+template
+struct DeviceOperationInstanceFactory<
+    ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta>
+{
+    using DeviceOp = DeviceNormalizationBwdGammaBeta;
+
+    static auto GetInstances()
+    {
+        std::vector> op_ptrs;
+
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v && is_same_v &&
+                     is_same_v && is_same_v &&
+                     is_same_v)
+        {
+            add_device_groupnorm_bwd_gamma_beta_f32_instances(op_ptrs);
+        }
+#endif
+        return op_ptrs;
+    }
+};
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp
new file mode 100644
index 0000000000..e2736ac77e
--- /dev/null
+++ b/library/include/ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp
@@ -0,0 +1,83 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
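// The layernorm factory below has the same shape as the groupnorm one above.
// A minimal sketch of the pattern (illustrative skeleton only; the real code
// is specialized on the full DeviceNormalizationBwdGammaBeta signature):
//
//   static auto GetInstances()
//   {
//       std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
//   #ifdef CK_ENABLE_FP32
//       if constexpr(is_same_v<DYDataType, float> /* ...remaining types... */
//                    && Rank == 2 && NumReduceDim == 1)
//           add_device_layernorm2d_bwd_gamma_beta_f32_instances(op_ptrs);
//   #endif
//       return op_ptrs; // empty if no instance list matches the template args
//   }
//
// Callers never name a concrete kernel: they ask the factory for every
// registered instance and pick the fastest at runtime, as the client examples
// above do.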
+ +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +// FP16 +void add_device_layernorm2d_bwd_gamma_beta_f16_instances( + std::vector>>&); +#endif +#ifdef CK_ENABLE_FP32 +// FP32 +void add_device_layernorm2d_bwd_gamma_beta_f32_instances( + std::vector>>&); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta> +{ + using DeviceOp = DeviceNormalizationBwdGammaBeta; + + static auto GetInstances() + { + std::vector> op_ptrs; +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_gamma_beta_f16_instances(op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(Rank == 2 && NumReduceDim == 1) + { + add_device_layernorm2d_bwd_gamma_beta_f32_instances(op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp index aa399f56ec..160bcb4ace 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f16_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f16_instances( +void add_device_layernorm2d_bwd_gamma_beta_f16_instances( std::vector>>& instances) { diff --git a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp index ba2966ba37..6f42eca0b9 100644 --- a/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization_bwd_gamma_beta/device_layernorm2d_bwd_gamma_beta_f32_instance.cpp @@ -8,7 +8,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_layernorm2d_bwd_gamma_beta_rank_2_1_f32_instances( +void add_device_layernorm2d_bwd_gamma_beta_f32_instances( std::vector>>& instances) { diff --git a/profiler/include/profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp b/profiler/include/profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp new file mode 100644 index 0000000000..5e9d3df1b1 --- /dev/null +++ b/profiler/include/profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
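// For orientation: the gamma/beta gradients this profiler exercises are the
// standard normalization backward-weight reductions (textbook form, assumed
// here to match what ReferenceGroupnormBwd computes below). For NHWGC input
// with per-(n, g) statistics:
//
//   dgamma[g][c] = sum over n,h,w of dy[n][h][w][g][c]
//                  * (x[n][h][w][g][c] - mean[n][g]) * inv_std[n][g]
//   dbeta[g][c]  = sum over n,h,w of dy[n][h][w][g][c]
//
// Both outputs reduce away the {N, H, W} axes, which is why reduceDims is
// {0, 1, 2} everywhere in this file.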
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_groupnorm_bwd_gamma_beta_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need GammaDataType and DXDataType here, just for reference class + using GammaDataType = DYDataType; + using DXDataType = DYDataType; + + if(length.size() != 5) + return false; + + index_t N = length[0]; + index_t G = length[3]; + index_t C = length[4]; + + std::vector reduce_dim = {0, 1, 2}; + std::vector gamma_beta_length = {G, C}; + + Tensor dy(length); + Tensor x(length); + Tensor gamma(gamma_beta_length); // dummy tensor, for reference + Tensor mean({N, G}); + Tensor inv_std({N, G}); + Tensor dgamma(gamma_beta_length); + Tensor dbeta(gamma_beta_length); + + Tensor host_dx(length); // dummy tensor, for reference + Tensor host_dgamma(gamma_beta_length); + Tensor host_dbeta(gamma_beta_length); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; + + std::vector strideDGamma{dgamma.mDesc.GetStrides().begin(), + dgamma.mDesc.GetStrides().end()}; + + std::vector strideDBeta{dbeta.mDesc.GetStrides().begin(), + dbeta.mDesc.GetStrides().end()}; + + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dgamma.GenerateTensorValue(GeneratorTensor_1{}); + dbeta.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{0, 5}); + dgamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + dbeta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0, 0.5}); + dgamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + dbeta.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // add device normalization instances + using DeviceOp = + 
ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) + + dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType); + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + gamma_beta_length, + strideDGamma, + strideDBeta, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + bool pass = + ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dgamma : ", host_dgamma.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "dgamma : ", dgamma.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." 
<< std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_layernorm_bwd_gamma_beta_impl.hpp b/profiler/include/profiler/profile_layernorm_bwd_gamma_beta_impl.hpp new file mode 100644 index 0000000000..10fa9c86d5 --- /dev/null +++ b/profiler/include/profiler/profile_layernorm_bwd_gamma_beta_impl.hpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_layernorm_bwd_gamma_beta_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector length) +{ + // we don't need GammaDataType and DXDataType here, just for reference class + using GammaDataType = DYDataType; + using DXDataType = DYDataType; + + if(length.size() != Rank || Rank < 2) + return false; + + // Assume normalize dimension for first dimension + // Layernorm 2D, input = [M, K], reduce on M axis + // Layernorm 4D, input = [N, H, W, C], redice on N axis + constexpr int NumReduceDim = Rank - 1; + + std::vector reduce_dim = {0}; + std::vector invarient_length{length.begin() + 1, length.end()}; + + Tensor dy(length); + Tensor x(length); + Tensor gamma(invarient_length); // dummy tensor, for reference + Tensor mean({length[0]}); + Tensor inv_std({length[0]}); + Tensor dgamma(invarient_length); + Tensor dbeta(invarient_length); + + Tensor host_dx(length); // dummy tensor, for reference + Tensor host_dgamma(invarient_length); + Tensor host_dbeta(invarient_length); + + std::vector strideDy = + std::vector{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector strideX = strideDy; + + std::vector strideDGamma{dgamma.mDesc.GetStrides().begin(), + dgamma.mDesc.GetStrides().end()}; + + std::vector strideDBeta{dbeta.mDesc.GetStrides().begin(), + dbeta.mDesc.GetStrides().end()}; + + std::vector strideMeanInvStd{Rank, 0}; + strideMeanInvStd[0] = 1; + + switch(init_method) + { + case 0: + dy.GenerateTensorValue(GeneratorTensor_1{}); + x.GenerateTensorValue(GeneratorTensor_1{}); + mean.GenerateTensorValue(GeneratorTensor_1{}); + inv_std.GenerateTensorValue(GeneratorTensor_1{}); + dgamma.GenerateTensorValue(GeneratorTensor_1{}); + dbeta.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 1: + dy.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + mean.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + inv_std.GenerateTensorValue(GeneratorTensor_2{0, 5}); + dgamma.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + 
dbeta.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + dy.GenerateTensorValue(GeneratorTensor_3{0, 1}); + x.GenerateTensorValue(GeneratorTensor_3{0, 1}); + mean.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0, 0.5}); + dgamma.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + dbeta.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // add device normalization instances + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto instance_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << instance_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + if(do_verification) + { + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, length); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + } + + std::size_t num_bytes = dy.mDesc.GetElementSize() * sizeof(DYDataType) + + x.mDesc.GetElementSize() * sizeof(XDataType) + + mean.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + inv_std.mDesc.GetElementSize() * sizeof(MeanInvStdDataType) + + dgamma.mDesc.GetElementSize() * sizeof(DGammaDataType) + + dbeta.mDesc.GetElementSize() * sizeof(DBetaDataType); + + int num_kernel = 0; + + for(auto& inst_ptr : instance_ptrs) + { + auto argument_ptr = inst_ptr->MakeArgumentPointer(length, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + invarient_length, + strideDGamma, + strideDBeta, + reduce_dim, + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(inst_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + } + else + { + if(time_kernel) + { + std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; + LogRange(std::cout << "input lengths = ", length, ", ") << std::endl; + } + + continue; + } + + size_t workspace_sz = inst_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + inst_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = inst_ptr->MakeInvokerPointer(); + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << inst_ptr->GetTypeString() << std::endl; + + if(avg_time < 
best_avg_time) + { + best_instance_name = inst_ptr->GetTypeString(); + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + bool pass = + ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "dy : ", dy.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_dgamma : ", host_dgamma.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "dgamma : ", dgamma.mData, ",") << std::endl; + } + + if(!pass) + { + std::cout << inst_ptr->GetTypeString() << " failed verification: "; + LogRange(std::cout << "lengths = [", length, ", ") << "]." << std::endl; + return false; + } + else + { + if(time_kernel) + std::cout << "pass" << std::endl; + } + } + } + + if(time_kernel) + { + LogRange(std::cout << "length = ", length, ",") << ", "; + LogRange(std::cout << "reduce dims ", reduce_dim, ",") << std::endl; + std::cout << "best perf = " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s," + << best_instance_name << std::endl; + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + return false; + } + + return true; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 68ef04ed11..e9cf6eecfb 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -19,6 +19,8 @@ set(PROFILER_SOURCES profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp profile_layernorm_bwd_data.cpp + profile_layernorm_bwd_gamma_beta.cpp + profile_groupnorm_bwd_gamma_beta.cpp profile_layernorm_fwd.cpp profile_max_pool3d_fwd.cpp profile_avg_pool3d_bwd.cpp @@ -82,6 +84,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) diff --git a/profiler/src/profile_groupnorm_bwd_gamma_beta.cpp b/profiler/src/profile_groupnorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..7fcef3a4e2 --- /dev/null +++ b/profiler/src/profile_groupnorm_bwd_gamma_beta.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
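// How the positional arguments in this file line up with argv (argv[1] is the
// operation name, consumed by the profiler registry), using the sample
// invocation from the help text:
//
//   ckProfiler groupnorm_bwd_gamma_beta 1 0 2 0 1 --length 1 16 16 32 40
//
//   argv[2] = 1  data type (fp32, per the help text)
//   argv[3] = 0  verification off
//   argv[4] = 2  init method: decimal values
//   argv[5] = 0  do not print tensor values
//   argv[6] = 1  time the kernel
//   --length     N H W G C extents; the impl reads N = length[0],
//                G = length[3], C = length[4]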
+ +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct groupnormBwdGammaBetaArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_groupnorm_bwd_gamma_beta() +{ + // eg: ckProfiler groupnorm_bwd_gamma_beta 1 0 2 0 1 --length 1 16 16 32 40 + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1 16 16 32 40) \n" + << std::endl; +} + +int profile_groupnorm_bwd_gamma_beta(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help_groupnorm_bwd_gamma_beta(); + return 0; + } + + groupnormBwdGammaBetaArgParser arg_parser; + + // short unnamed options + const ck::DataTypeEnum data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + + using F32 = float; + + if(length.size() == 5) + { + if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_groupnorm_bwd_gamma_beta_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + else + { + throw std::runtime_error("length should be 5"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("groupnorm_bwd_gamma_beta", + "Group Normalization", + profile_groupnorm_bwd_gamma_beta); diff --git a/profiler/src/profile_layernorm_bwd_gamma_beta.cpp b/profiler/src/profile_layernorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000..0f3436c663 --- /dev/null +++ b/profiler/src/profile_layernorm_bwd_gamma_beta.cpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
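+// +// Example invocation (mirrors the help text below): +// ckProfiler layernorm_bwd_gamma_beta 0 0 2 0 1 --length 1502 4096 +// Here --length supplies the 2D [N, D] extents used by the 2D layernorm backward kernels.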
+ +#include +#include +#include + +#include "profiler/data_type_enum.hpp" +#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp" +#include "profiler_operation_registry.hpp" + +using ck::index_t; + +struct layernormBwdGammaBetaArgParser +{ + std::unordered_map> long_opts = {{"length", {}}}; + + bool parse_opt(int argc, char* argv[], const std::string& key, int i) + { + if(std::string("--") + key == argv[i]) + { + int pos = i; + while(++i < argc && argv[i][0] != '-') {} + int end = i; + for(int j = pos + 1; j < end; j++) + { + long_opts[key].push_back(std::stoi(argv[j])); + } + return true; + } + return false; + } + + void operator()(int argc, char* argv[]) + { + for(auto& kv : long_opts) + { + for(int i = 1; i < argc; i++) + { + if(parse_opt(argc, argv, kv.first, i)) + break; + } + } + } +}; + +void print_help_layernorm_bwd_gamma_beta() +{ + // eg: ckProfiler layernorm_bwd_gamma_beta 0 0 2 0 1 --length 1502 4096 + std::cout << "arg1: data type (0: fp16; 1: fp32)\n" + << "arg2: verification (0: no; 1: yes)\n" + << "arg3: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg4: print tensor value (0: no; 1: yes)\n" + << "arg5: time kernel (0=no, 1=yes)\n" + << "--length: tensor extents (e.g, --length 1024 1024) \n" + << std::endl; +} + +int profile_layernorm_bwd_gamma_beta(int argc, char* argv[]) +{ + if(argc <= 2) + { + print_help_layernorm_bwd_gamma_beta(); + return 0; + } + + layernormBwdGammaBetaArgParser arg_parser; + + // short unnamed options + const ck::DataTypeEnum data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + // parse the long options + arg_parser(argc, argv); + const std::vector length = arg_parser.long_opts["length"]; + + using F16 = ck::half_t; + using F32 = float; + + if(length.size() == 2) + { + constexpr int rank = 2; + + if(data_type == ck::DataTypeEnum::Half) + { + ck::profiler::profile_layernorm_bwd_gamma_beta_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else if(data_type == ck::DataTypeEnum::Float) + { + ck::profiler::profile_layernorm_bwd_gamma_beta_impl( + do_verification, init_method, do_log, time_kernel, length); + } + else + { + throw std::runtime_error("not implemented yet"); + } + } + else + { + throw std::runtime_error("not implemented yet"); + } + + return 0; +} + +REGISTER_PROFILER_OPERATION("layernorm_bwd_gamma_beta", + "Layer Normalization", + profile_layernorm_bwd_gamma_beta); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 90140659f6..fa5f8583af 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -140,6 +140,7 @@ add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) add_subdirectory(normalization_fwd) add_subdirectory(normalization_bwd_data) +add_subdirectory(normalization_bwd_gamma_beta) add_subdirectory(data_type) add_subdirectory(elementwise_normalization) add_subdirectory(batchnorm) diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt new file mode 100644 index 0000000000..f3579aad08 --- /dev/null +++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt @@ -0,0 +1,13 @@ +add_custom_target(test_normalization_bwd_gamma_beta) +add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility 
device_normalization_bwd_gamma_beta_instance) + add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) +endif() + +add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) + add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) +endif() + diff --git a/test/normalization_bwd_gamma_beta/test_groupnorm_bwd_gamma_beta_fp32.cpp b/test/normalization_bwd_gamma_beta/test_groupnorm_bwd_gamma_beta_fp32.cpp new file mode 100644 index 0000000000..ab9cb29891 --- /dev/null +++ b/test/normalization_bwd_gamma_beta/test_groupnorm_bwd_gamma_beta_fp32.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_groupnorm_bwd_gamma_beta_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestgroupnormBwdGammaBeta : public ::testing::Test +{ + protected: + using DYDataType = std::tuple_element_t<0, Tuple>; + using XDataType = std::tuple_element_t<1, Tuple>; + using MeanInvStdDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using DGammaDataType = std::tuple_element_t<4, Tuple>; + using DBetaDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // Bwd data: [N, H, W, G, C], reduce H, W, C + std::vector> lengths = {{1, 1, 1, 1, 1}, + {1, 2, 3, 4, 5}, + {256, 9, 9, 9, 9}, + {1, 64, 64, 32, 10}, + {1, 32, 32, 32, 20}, + {1, 16, 16, 32, 40}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_groupnorm_bwd_gamma_beta_impl( + true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // DYDataType XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestgroupnormBwdGammaBeta, KernelTypes); +TYPED_TEST(TestgroupnormBwdGammaBeta, Test_FP32) { this->Run(); } diff --git a/test/normalization_bwd_gamma_beta/test_layernorm2d_bwd_gamma_beta_fp32.cpp b/test/normalization_bwd_gamma_beta/test_layernorm2d_bwd_gamma_beta_fp32.cpp new file mode 100644 index 0000000000..53c92413b1 --- /dev/null +++ b/test/normalization_bwd_gamma_beta/test_layernorm2d_bwd_gamma_beta_fp32.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
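+// +// Note: profile_layernorm_bwd_gamma_beta_impl below takes its arguments in the +// order (do_verification, init_method, do_log, time_kernel, lengths), so the +// call (true, 2, false, false, length) verifies against the host reference with +// decimal-value initialization and with logging and kernel timing disabled.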
+ +#include "gtest/gtest.h" +#include "profiler/profile_layernorm_bwd_gamma_beta_impl.hpp" + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestLayernorm2dBwdGammaBeta : public ::testing::Test +{ + protected: + using DYDataType = std::tuple_element_t<0, Tuple>; + using XDataType = std::tuple_element_t<1, Tuple>; + using MeanInvStdDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using DGammaDataType = std::tuple_element_t<4, Tuple>; + using DBetaDataType = std::tuple_element_t<5, Tuple>; + + void Run() + { + // Bwd data: [N, D], reduce D + std::vector> lengths = { + {4, 256}, {8, 511}, {9, 1032}, {4, 2048}, {1, 8192}, {4000, 2000}}; + + for(auto length : lengths) + { + bool success = ck::profiler::profile_layernorm_bwd_gamma_beta_impl( + true, 2, false, false, length); + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + // DYDataType XDataType, MeanInvStdDataType, ComputeDataType, DGammaDataType, DBetaDataType> + std::tuple>; + +TYPED_TEST_SUITE(TestLayernorm2dBwdGammaBeta, KernelTypes); +TYPED_TEST(TestLayernorm2dBwdGammaBeta, Test_FP32) { this->Run(); } From 4a8297c28ada39c3e8bbcf98bf3addb5f733ae94 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 25 Jan 2024 17:05:43 -0800 Subject: [PATCH 61/75] fix CK path for hipTensor (#1143) --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 9359aa1f69..8ab326da7a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -560,7 +560,7 @@ def Build_CK(Map conf=[:]){ sh """#!/bin/bash mkdir -p build ls -ltr - CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="/opt/rocm;${env.WORKSPACE}/install" + CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" cmake --build build -- -j """ } From 84832fc42d71e446fa2ddbf88b96fc2c05b21b49 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 09:02:52 -0800 Subject: [PATCH 62/75] Bump rocm-docs-core from 0.31.0 to 0.33.0 in /docs/sphinx (#1144) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.31.0 to 0.33.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.31.0...v0.33.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 23a4c4bb91..88142aa373 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.31.0 +rocm-docs-core==0.33.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 1e5e688dac..12414c7470 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.31.0 +rocm-docs-core==0.33.0 # via -r requirements.in six==1.16.0 # via From e7495e6bb7d0a112c30e51249fb6b537b9ab96ce Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 30 Jan 2024 13:14:58 -0800 Subject: [PATCH 63/75] turn off performance tests in CI by default until the infrastructure is fixed (#1147) --- Jenkinsfile | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 8ab326da7a..c1363b5d2a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -713,6 +713,10 @@ pipeline { name: "RUN_CPPCHECK", defaultValue: false, description: "Run the cppcheck static analysis (default: OFF)") + booleanParam( + name: "RUN_PERFORMANCE_TESTS", + defaultValue: false, + description: "Run the performance tests (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -890,7 +894,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } } options { retry(2) } agent{ label rocmnode("gfx908 || gfx90a")} @@ -906,7 +910,7 @@ pipeline { { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } } options { retry(2) } agent{ label rocmnode("gfx90a")} @@ -925,6 +929,10 @@ pipeline { parallel { stage("Process results"){ + when { + beforeAgent true + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } + } agent { label 'mici' } steps{ process_results() From 6651a124cc4c467daf1a8ceaededc51bf0b7f656 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 30 Jan 2024 13:55:31 -0800 Subject: [PATCH 64/75] update the name of the compiler staging branch (#1148) --- Dockerfile | 4 ++-- Jenkinsfile | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Dockerfile b/Dockerfile index a805285a77..48ee97eec2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -122,7 +122,7 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" 
-DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -130,7 +130,7 @@ RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "am else echo "using the release compiler"; \ fi -RUN if ( [ "$compiler_version" = "amd-stg-open" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ diff --git a/Jenkinsfile b/Jenkinsfile index c1363b5d2a..80e7b044f1 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -84,7 +84,7 @@ def build_compiler(){ compiler = '/opt/rocm/bin/hipcc' } else{ - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ compiler = "/llvm-project/build/bin/clang++" } else{ @@ -293,7 +293,7 @@ def buildHipClangJob(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -348,7 +348,7 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -479,7 +479,7 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -657,7 +657,7 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode 
and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.0;COMPILER_VERSION= 0 21 * * * % ROCMVERSION=6.0;COMPILER_VERSION=;COMPILER_COMMIT= - 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=;USE_SCCACHE=false + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" pipeline { @@ -680,7 +680,7 @@ pipeline { string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-stg-open, amd-mainline-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', From f3b6c23ac59a6bab5d45aeee2320e1ea91f33178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 31 Jan 2024 21:24:40 +0100 Subject: [PATCH 65/75] Add blockwise gemm to ck wrapper (#1139) * Add blockwise gemm to ck wrapper * Add blockwise gemm traits * Disable test_gemm for non xdl devices * Fixes * Add c layout descritpions --- CHANGELOG.md | 2 +- docs/wrapper.rst | 1 + include/ck/wrapper/layout.hpp | 3 + include/ck/wrapper/operations/copy.hpp | 125 +++++-- include/ck/wrapper/operations/gemm.hpp | 337 ++++++++++++++++++ include/ck/wrapper/tensor.hpp | 39 +- .../traits/blockwise_gemm_xdl_traits.hpp | 48 +++ include/ck/wrapper/utils/tensor_partition.hpp | 295 ++++++++++++--- include/ck/wrapper/utils/tensor_utils.hpp | 35 +- test/wrapper/CMakeLists.txt | 6 + test/wrapper/test_gemm.cpp | 257 +++++++++++++ test/wrapper/test_partition.cpp | 32 +- 12 files changed, 1064 insertions(+), 116 deletions(-) create mode 100644 include/ck/wrapper/operations/gemm.hpp create mode 100644 include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp create mode 100644 test/wrapper/test_gemm.cpp diff --git a/CHANGELOG.md b/CHANGELOG.md index c721039523..4e3feed2df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ None None ### Additions -* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) +* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126, #1139) ### Changes None diff --git a/docs/wrapper.rst b/docs/wrapper.rst index 79b6c75580..c64c0bf17f 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -89,3 +89,4 @@ Operations ------------------------------------- .. doxygenfile:: copy.hpp +.. doxygenfile:: gemm.hpp diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 39b5c79c67..71c512e136 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -248,6 +248,9 @@ struct Layout using DefaultIdxsTupleType = remove_cvref_t; public: + using LayoutShape = Shape; + using LayoutUnrolledDescriptorType = UnrolledDescriptorType; + /** * \brief Transform descriptor to align to passed indexes. 
* diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index 7b00fe5500..614dfd758e 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -3,45 +3,18 @@ #pragma once -#include "../utils/tensor_utils.hpp" +#include "ck/wrapper/utils/tensor_utils.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" namespace ck { namespace wrapper { -/** - * \brief Perform generic copy between two tensors partitions (threadwise copy). - * Tensors must have the same size. - * - * \param src_tensor Source tensor. - * \param dst_tensor Destination tensor. - */ -template -__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) -{ - if constexpr(!SrcTensorType::IsDynamicBuffer) - { - using SizeType = decltype(size(src_tensor)); - static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); - } - else if constexpr(!DstTensorType::IsDynamicBuffer) - { - using SizeType = decltype(size(dst_tensor)); - static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); - } - else - { - for(int i = 0; i < size(src_tensor); i++) - { - dst_tensor(i) = src_tensor(i); - } - } -} - /** * \brief Perform optimized copy between two tensors partitions (threadwise copy). * Tensors must have the same size. @@ -167,9 +140,99 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) else { // Perform copy between StaticBuffers - copy(src_tensor, dst_tensor); + static_for<0, SrcShapeType::Size(), 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); }); } } +/** + * \brief Perform generic copy between two tensors partitions (threadwise copy). + * Tensors must have the same size. + * + * \param src_tensor Source tensor. + * \param dst_tensor Destination tensor. + */ +template +__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) +{ + // Generate default params + using SrcShapeType = remove_cvref_t; + constexpr index_t num_dims = SrcShapeType::Size(); + // Incrementing dims 0, 1, 2 ... num_dims - 1 + constexpr auto dim_access_order_tuple = + generate_tuple([](auto i) { return Number{}; }, Number{}); + constexpr index_t vector_dim = num_dims - 1; + constexpr index_t scalar_per_vector = 1; + copy(src_tensor, dst_tensor); +} + +/** + * \brief Perform optimized blockwise copy between two tensors. Tensors must have the + * same size. + * + * \note At now Vgpr and Sgpr are not supported. + * + * \tparam DimAccessOrderTuple Tuple with dimension access order. + * \tparam VectorDim Dimension for vectorize read and write. + * \tparam ScalarPerVector Number of scalar per vectorize read and write. + * \param src_tensor Source tensor. + * \param dst_tensor Destination tensor. + * \param thread_layout Thread layout per each dimension for copy. 
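+ * + * \note Illustrative call with placeholder template values: for 2D tiles of + * matching shape, blockwise_copy<Tuple<Number<0>, Number<1>>, 1, 8>(src_tile, + * dst_tile, thread_layout) vectorizes the copy along the last dimension with + * 8 scalars per access (compare the blockwise_copy calls in test_gemm.cpp below).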
+ */ +template +__device__ void blockwise_copy(const SrcTensorType& src_tensor, + DstTensorType& dst_tensor, + [[maybe_unused]] ThreadLayoutTuple& thread_layout) +{ + static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer); + static_assert(is_detected::value); + + const auto& in_grid_desc = layout(src_tensor).GetUnrolledDescriptor(); + const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor(); + + using SrcShapeType = remove_cvref_t; + constexpr index_t num_dims = SrcShapeType::Size(); + + constexpr auto tile_lengths_seq = + generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number{}); + constexpr auto thread_layout_seq = generate_sequence_v2( + [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number{}); + constexpr auto dim_access_order = generate_sequence_v2( + [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number{}); + + using ThisThreadBlock = ThisThreadBlock; + + // Perform copy between DynamicBuffers + auto transfer = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + Tuple, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(InMemoryDataOperationEnum::Set)>, + std::remove_const_t, + std::remove_const_t, + std::remove_const_t, + std::remove_const_t, + VectorDim, + ScalarPerVector, + Sequence, + Sequence>{in_grid_desc, + make_tuple(src_tensor.GetMultiIdxOffsets()), + out_grid_desc, + make_tuple(dst_tensor.GetMultiIdxOffsets()), + tensor_operation::element_wise::PassThrough{}}; + + transfer.Run(tie(in_grid_desc), + tie(src_tensor.GetBuffer()), + tie(out_grid_desc), + tie(dst_tensor.GetBuffer())); +} + } // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/operations/gemm.hpp b/include/ck/wrapper/operations/gemm.hpp new file mode 100644 index 0000000000..9b8c0543fd --- /dev/null +++ b/include/ck/wrapper/operations/gemm.hpp @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/wrapper/utils/tensor_utils.hpp" +#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" + +namespace ck { +namespace wrapper { + +namespace { +namespace detail { +/** + * \brief Create block descriptor (K0, MPerBlock or NPerBlock, K1). + * + * + * \tparam K1 The number of K-dim elements that are packed together as a separate logical dimension. + * \tparam TileLayout Tensor data tile layout (M,K) or (N,K). + * + * \return Block descriptor (K0, MPerBlock or NPerBlock, K1) + */ +template +__device__ constexpr auto GetBlockDescriptor() +{ + using TileLayoutShape = typename TileLayout::LayoutShape; + using TileLayoutDescriptor = typename TileLayout::LayoutUnrolledDescriptorType; + + constexpr auto K0PerBlock = Number(TileLayoutShape{})>{} / Number{}; + // MPerBlock or NPerBlock + constexpr auto Dim0 = Number(TileLayoutShape{})>{}; + + constexpr auto a_block_desc_k0_m_k1 = transform_tensor_descriptor( + TileLayoutDescriptor{}, + make_tuple(make_unmerge_transform(make_tuple(K0PerBlock, Number{})), + make_pass_through_transform(Dim0)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_block_desc_k0_m_k1; +} + +} // namespace detail +} // namespace + +/** + * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be + * stored in Vgpr register. 
A data layout must be (MPerBlock, KPerBlock) and B + * data layout must be (NPerBlock, KPerBlock). + * + * \note C output Vgpr register layout (8D): + * - MXdlPerWave - The number of MFMA instructions run by single wave in M + * dimension per tile. + * - NXdlPerWave - The number of MFMA instructions run by single wave in N + * dimension per tile. + * - MWave - Equals to 1 since this is for single wave. + * - NWave - Equals to 1 since this is for single wave. + * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the + * instruction size). + * - NumInputsBlock - Mfma instruction internal layout (depends on the + * instruction size). + * - GroupSize - Mfma instruction internal layout (depends on the + * instruction size). + * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the + * instruction size). + * + * \tparam DataType Input data types. + * \tparam BlockSize Number of threads in block. + * \tparam GemmTraits Traits of gemm xdl operation. + * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm + * (MPerBlock, KPerBlock) layout. + * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm + * (NPerBlock, KPerBlock) layout. + * \param c_reg_tensor C tensor in Vgpr registers for blockwise gemm. + */ +template +__device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, + const BTensorType& b_local_tile_tensor, + CTensorType& c_reg_tensor) +{ + static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); + static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); + static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr); + static_assert(is_same_v); + static_assert(is_same_v); + + constexpr bool is_integer = + is_same_v || is_same_v || is_same_v; + using GemmAccDataType = std::conditional_t; + + using ATileLayout = remove_cvref_t; + using BTileLayout = remove_cvref_t; + + using ABlockDesc_K0_M_K1_Type = + decltype(detail::GetBlockDescriptor()); + using BBlockDesc_K0_N_K1_Type = + decltype(detail::GetBlockDescriptor()); + + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + blockwise_gemm_xdl_op{}; + + blockwise_gemm_xdl_op.Run( + a_local_tile_tensor.GetBuffer(), b_local_tile_tensor.GetBuffer(), c_reg_tensor.GetBuffer()); +} + +/** + * \brief Create local partition per thread for C tensor. + * + * \note C output global memory layout (8D): + * - MXdlPerWave - The number of MFMA instructions run by single wave in M + * dimension. + * - NXdlPerWave - The number of MFMA instructions run by single wave in N + * dimension. + * - MWave - The number of waves in the M dimension of a single tile. + * - NWave - The number of waves in the N dimension of a single tile. + * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the + * instruction size). + * - NumInputsBlock - Mfma instruction internal layout (depends on the + * instruction size). + * - GroupSize - Mfma instruction internal layout (depends on the + * instruction size). + * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the + * instruction size). + * + * \tparam DataType Input data types. + * \tparam ATileLayout A tensor layout. + * \tparam BTileLayout B tensor layout. + * \tparam BlockSize Number of threads in block. + * \tparam GemmTraits Traits of gemm xdl operation. + * \param c_local_tile_tensor C tensor tile in global memory for blockwise gemm + * (MPerBlock, NPerBlock) layout. + * + * \return Partition c tensor for blockwise gemm.
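+ * + * \note Illustrative pairing, following test_gemm.cpp below (ALayout/BLayout + * stand for the A/B tile layout types): the returned partition is the + * per-thread destination for the Vgpr accumulator, e.g. + * auto c_part = make_blockwise_gemm_xdl_c_local_partition<DataType, ALayout, + * BLayout, BlockSize, GemmTraits>(c_tile); followed by copy(c_vgpr_reg, c_part);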
+ */ +template +__host__ __device__ constexpr auto +make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + constexpr bool is_integer = + is_same_v || is_same_v || is_same_v; + using GemmAccDataType = std::conditional_t; + + using ABlockDesc_K0_M_K1_Type = + decltype(detail::GetBlockDescriptor()); + using BBlockDesc_K0_N_K1_Type = + decltype(detail::GetBlockDescriptor()); + + using BlockwiseGemmXdlops = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + BlockwiseGemmXdlops::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // Calculate offset on grid + const auto c_thread_mtx_on_block = + BlockwiseGemmXdlops::CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + c_local_tile_tensor.GetMultiIdxOffsets()[I0] + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + c_local_tile_tensor.GetMultiIdxOffsets()[I1] + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + // Create partition shape based on descriptor dims. + const auto partition_shape = make_tuple(M0, N0, I1, I1, M2, I1, M4, I1); + + const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( + layout(c_local_tile_tensor).GetUnrolledDescriptor()); + const auto partition_layout = + Layout, decltype(partition_desc)>( + partition_shape, partition_desc); + auto partition_tensor = make_tensor( + c_local_tile_tensor.GetPointer(), partition_layout); + partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2])); + return partition_tensor; +} + +/** + * \brief Create Vgpr register tensor for the C output of blockwise gemm.
+ * + * \note C output Vgpr register layout (8D): + * - MXdlPerWave - The number of MFMA instructions run by single wave in M + * dimension per tile. + * - NXdlPerWave - The number of MFMA instructions run by single wave in N + * dimension per tile. + * - MWave - Equals to 1 since this is for single wave. + * - NWave - Equals to 1 since this is for single wave. + * - NumGroupsPerBlock - Mfma instruction internal layout (depends on the + * instruction size). + * - NumInputsBlock - Mfma instruction internal layout (depends on the + * instruction size). + * - GroupSize - Mfma instruction internal layout (depends on the + * instruction size). + * - NumThreadsPerBlock - Mfma instruction internal layout (depends on the + * instruction size). + * + * \tparam DataType Input data types. + * \tparam ATileLayout A tensor layout. + * \tparam BTileLayout B tensor layout. + * \tparam BlockSize Number of threads in block. + * \tparam GemmTraits Traits of gemm xdl operation. + * + * \return Vgpr c tensor for blockwise gemm. + */ +template +__host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr() +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + + constexpr bool is_integer = + is_same_v || is_same_v || is_same_v; + using GemmAccDataType = std::conditional_t; + + using ABlockDesc_K0_M_K1_Type = + decltype(detail::GetBlockDescriptor()); + using BBlockDesc_K0_N_K1_Type = + decltype(detail::GetBlockDescriptor()); + + using BlockwiseGemmXdlops = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + // Calculate descriptor, shape and layout + constexpr auto vgpr_desc = BlockwiseGemmXdlops::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + const auto vgpr_shape = make_tuple(vgpr_desc.GetLengths()[I0], + vgpr_desc.GetLengths()[I1], + vgpr_desc.GetLengths()[I2], + vgpr_desc.GetLengths()[I3], + vgpr_desc.GetLengths()[I4], + vgpr_desc.GetLengths()[I5], + vgpr_desc.GetLengths()[I6], + vgpr_desc.GetLengths()[I7]); + const auto vgpr_layout = Layout, decltype(vgpr_desc)>( + vgpr_shape, vgpr_desc); + // Get vector type for Vgpr + using BlockwiseGemmCThreadBufferType = + remove_reference_t; + using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V; + return ck::wrapper::make_register_tensor( + vgpr_layout); +} + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 57d79c5940..e344399dbf 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -10,8 +10,8 @@ namespace ck { namespace wrapper { -namespace detail { namespace { +namespace { +namespace detail { /** * \brief Check if Tuple contains Slice object * @@ -187,8 +187,8 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){}; return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims); } -} // namespace } // namespace detail +} // namespace detail +} // namespace /** * \brief Tensor wrapper that performs static and dynamic buffer logic. 
@@ -209,7 +209,10 @@ struct Tensor public: using ElementSpaceSize = decltype(Layout{ Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer - using TensorElementType = ElementType; // DataType + using TensorElementType = std::conditional_t< + is_scalar_type::value, + ElementType, + typename scalar_type>::type>; // DataType static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace; static constexpr bool IsDynamicBuffer = !(BufferAddressSpace == MemoryTypeEnum ::Sgpr || @@ -280,7 +283,7 @@ struct Tensor * \return Requested value. */ template {}), bool> = false> - __host__ __device__ const ElementType& operator[](const Tuple& idx) const + __host__ __device__ const TensorElementType& operator[](const Tuple& idx) const { if constexpr(IsDynamicBuffer) { @@ -301,13 +304,13 @@ struct Tensor } template {}), bool> = false> - __host__ __device__ const ElementType& operator()(const Tuple& idx) const + __host__ __device__ const TensorElementType& operator()(const Tuple& idx) const { return this->operator[](idx); } template {}), bool> = false> - __host__ __device__ const ElementType& operator()(Idxs... idxs) const + __host__ __device__ const TensorElementType& operator()(Idxs... idxs) const { return this->operator[](make_tuple(idxs...)); } @@ -319,7 +322,7 @@ struct Tensor * \return Requested value. */ template {}), bool> = false> - __host__ __device__ ElementType& operator[](const Tuple& idx) + __host__ __device__ TensorElementType& operator[](const Tuple& idx) { if constexpr(IsDynamicBuffer) { @@ -340,13 +343,13 @@ struct Tensor } template {}), bool> = false> - __host__ __device__ ElementType& operator()(const Tuple& idx) + __host__ __device__ TensorElementType& operator()(const Tuple& idx) { return this->operator[](idx); } template {}), bool> = false> - __host__ __device__ ElementType& operator()(Idxs... idxs) + __host__ __device__ TensorElementType& operator()(Idxs... idxs) { return this->operator[](make_tuple(idxs...)); } @@ -366,7 +369,7 @@ struct Tensor * * \return Pointer. */ - __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; } + __host__ __device__ TensorElementType* GetPointer() const { return buffer_.p_data_; } __host__ __device__ constexpr auto& GetBuffer() { return buffer_; } __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; } @@ -395,10 +398,18 @@ struct Tensor ElementType, ElementSpaceSize, true /*InvalidElementUseNumericalZeroValue*/>; - using StaticBufferType = StaticBuffer; + using StaticBufferType = std::conditional_t< + is_scalar_type::value, + StaticBuffer, + StaticBufferTupleOfVector>::vector_size, + scalar_type>::vector_size, + true /*InvalidElementUseNumericalZeroValue*/>>; // If register use static buffer, else use dynamic buffer using Buffer = std::conditional_t; diff --git a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp new file mode 100644 index 0000000000..24d863f5b1 --- /dev/null +++ b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { +namespace wrapper { + +/** + * \brief Traits for blockwise gemm xdl. + * + * \tparam MPerXDLValue The MFMA instruction size in M dimension. + * \tparam NPerXDLValue The MFMA instruction size in N dimension. 
+ * \tparam MXdlPerWaveValue The number of MFMA instructions run by single + * wave in M dimension. + * \tparam NXdlPerWaveValue The number of MFMA instructions run by single + * wave in N dimension. + * \tparam K1Value The number of K-dim elements that are packed together as + * a separate logical dimension. Usually aligns with vector load size. + */ +template +struct BlockwisGemmXdlTraits +{ + static constexpr index_t MPerXDL = MPerXDLValue; + static constexpr index_t NPerXDL = NPerXDLValue; + static constexpr index_t MXdlPerWave = MXdlPerWaveValue; + static constexpr index_t NXdlPerWave = NXdlPerWaveValue; + static constexpr index_t K1 = K1Value; +}; + +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> +{ +}; + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index 6aae5a92fe..5638382dba 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -6,6 +6,7 @@ #include "tensor_utils.hpp" #include "layout_utils.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_description/cluster_descriptor.hpp" @@ -14,6 +15,8 @@ namespace wrapper { namespace { +namespace detail { + /** * \brief Calculate shape for partition based on number of threads per each dim and * previous shape @@ -30,26 +33,109 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple{}; - const auto slice_len = size(shape) / thread_lengths.At(num_i); + const auto slice_len = + ck::math::integer_divide_ceil(size(shape), thread_lengths.At(num_i)); return slice_len; }, Number::Size()>{}); } +/** + * \brief Apply projection. + * + * \param base_tuple Tuple to apply projection. + * \param projection Projection to remove selected dim from partitioning. + * slice(X) to remove, where X is dim size, Number<1>{} to keep. + * \return Multi index after projection. + */ +template +__host__ __device__ constexpr auto +ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, + [[maybe_unused]] const ProjectionTuple& projection) +{ + if constexpr(is_same_v>) + { + return Tuple<>{}; + } + else + { + auto base_tuple_after_projection = generate_tuple( + [&](auto i) { + const auto i_num = Number{}; + static_assert( + is_detected>::value || + is_same_v, Number<1>>); + if constexpr(is_detected>::value) + { + // When slice (to remove), then insert empty tuple (will be removed in next + // step). + return Tuple<>{}; + } + else + { + return base_tuple.At(i_num); + } + }, + Number{}); + // Remove empty tuples + return UnrollNestedTuple<0, 1>(base_tuple_after_projection); + } +} + +/** + * \brief Calculate shape with dims from projection. + * + * \param shape Base tensor shape. + * \param projection Projection to remove selected dim from partitioning. + * slice(X) to remove, where X is dim size, Number<1>{} to keep. 
+ * \return Shape with dims from projection + */ +template +__host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple& shape, + const Tuple& projection) +{ + return generate_tuple( + [&](auto i) { + if constexpr(is_detected>>::value) + { + return size(projection).to_; + } + else + { + // Number of shape elements in the fragment of shape and projection processed + // so far (used to compute the index into shape) + constexpr index_t shape_i = + detail::ApplyProjection(TupleSlice<0, i>(Tuple{}), + TupleSlice<0, i>(Tuple{})) + .Size(); + return size(shape); + } + }, + Number::Size()>{}); +} + /** * \brief Calculate total number of blocks. * * \param shape Base tensor shape. * \param tile_shape Tile shape. + * \param projection Projection is used to remove selected dim from + * partitioning. Use `slice(X)` to remove dimension, where X is dim + * size. Use `Number<1>{}` to keep it. * \return Tuple with blocks number. */ -template +template __host__ __device__ constexpr auto CalculateGridSize(const Tuple& shape, - const Tuple& tile_shape) + const Tuple& tile_shape, + const Tuple& projection) { - static_assert(Tuple::Size() == Tuple::Size(), "Wrong thread_lengths shape."); - return generate_tuple([&](auto i) { return size(shape) / size(tile_shape); }, - Number::Size()>{}); + auto shape_with_projection = CalculateShapeWithProjection(shape, projection); + return generate_tuple( + [&](auto i) { + return ck::math::integer_divide_ceil(size(shape_with_projection), + size(tile_shape)); + }, + Number::Size()>{}); } /** @@ -69,8 +155,75 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, return thread_idxs * partition_lengths_seq + old_offset_idxs; } +/** + * \brief Calculate default projection. + * + * \param tile_shape Tile shape. + * \return Default projection (filled with Number<1>{}). + */ +template +__host__ __device__ constexpr auto +GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape) +{ + return generate_tuple([&](auto) { return Number<1>{}; }, Number{}); +} + +} // namespace detail } // namespace +/** + * \brief Create local partition for thread (currently only packed partitions + * are supported). + * + * \param tensor Tensor for partition. + * \param thread_lengths Layout of threads (must not be nested). + * \param thread_id Thread index represented as integer. + * \param projection Projection is used to remove selected dim from + * partitioning. Use `slice(X)` to remove dimension, where X is dim + * size. Use `Number<1>{}` to keep it. + * \return Partition tensor. 
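+ * + * \note The projection convention matches make_local_tile below; for example, + * test_gemm.cpp passes make_tuple(Number<1>{}, slice(N), Number<1>{}) for the + * A operand, keeping the M and K dimensions for partitioning and skipping N.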
+ */ +template +__host__ __device__ constexpr auto +make_local_partition(TensorType& tensor, + [[maybe_unused]] const ThreadLengthsTuple& thread_lengths, + const index_t thread_id, + const ProjectionTuple& projection) +{ + static_assert(!IsNestedTuple(ThreadLengthsTuple{})); + // Calculate new partition shape + const auto& tensor_shape = shape(tensor); + // Calculate projected thread lengths + constexpr auto projected_thread_lengths = + detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{}); + constexpr auto partition_shape = + detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths); + // Create Thread Cluster Descriptor + constexpr auto partition_shape_seq = + generate_sequence_v2([&](auto I) { return size(partition_shape); }, + Number{}); + constexpr auto thread_lengths_seq = + generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); }, + Number{}); + constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq); + // Calculate thread idxs and offsets + const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); + // Apply projection on thread idxs to remove unneeded idxs + const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection); + const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( + projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets()); + // Create new layout and tensor + auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); + const auto partition_layout = + Layout, decltype(unrolled_desc)>( + partition_shape, unrolled_desc); + auto partition_tensor = + make_tensor(tensor.GetPointer(), partition_layout); + // Apply offsets + partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + return partition_tensor; +} + /** * \brief Create local partition for thread (currently only packed partitions * are supported). @@ -81,37 +234,12 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, * \return Partition tensor. 
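+ * + * \note This overload forwards to the projection-aware overload above with a + * default projection of Number<1>{} in every dimension, i.e. all dimensions + * take part in the partitioning.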
*/ template -__host__ __device__ constexpr auto -make_local_partition(TensorType& tensor, - [[maybe_unused]] const ThreadLengthsTuple& thread_lengths, - const index_t thread_id) +__host__ __device__ constexpr auto make_local_partition(TensorType& tensor, + const ThreadLengthsTuple& thread_lengths, + const index_t thread_id) { - static_assert(!IsNestedTuple(ThreadLengthsTuple{})); - // Calculate new partition shape - const auto& tensor_shape = shape(tensor); - constexpr auto partition_shape = - CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{}); - // Create Thread Cluster Descriptor - constexpr auto partition_lengths_seq = generate_sequence_v2( - [&](auto I) { return size(partition_shape); }, Number{}); - constexpr auto thread_lengths_seq = - generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); }, - Number{}); - constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq); - // Calculate thread idxs and offsets - const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); - const auto offset_multi_idxs = - CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets()); - // Create new layout and tensor - auto& flatten_desc = layout(tensor).GetUnrolledDescriptor(); - const auto partition_layout = - Layout, decltype(flatten_desc)>( - partition_shape, flatten_desc); - auto partition_tensor = - make_tensor(tensor.GetPointer(), partition_layout); - // Apply offsets - partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); - return partition_tensor; + const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{}); + return make_local_partition(tensor, thread_lengths, thread_id, projection); } /** @@ -125,22 +253,29 @@ make_local_partition(TensorType& tensor, * \param tensor Tensor for partition. * \param tile_shape Shapes of requested tile. * \param block_id Block index represented as integer. - + * \param projection Projection to remove selected dim from partitioning. + * slice(X) to remove, where X is dim size, Number<1>{} to keep. * \return Tile tensor. */ -template -__host__ __device__ constexpr auto -make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id) +template +__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, + const BlockShapeTuple& tile_shape, + const index_t block_id, + const ProjectionTuple& projection) { static_assert(!IsNestedTuple(BlockShapeTuple{})); + constexpr bool is_default_projection = + is_same_v; + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor(); - if constexpr(BlockShapeTuple::Size() == I2) + // TODO: Enable block_2_tile_map partitioning for non-default projection. 
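+ // Dispatch note: the 2D default-projection case below uses the optimized + // block-to-C-tile mapping (see block_to_ctile_map.hpp) to map block_id onto + // the tile grid; all other cases take the generic cluster-descriptor path, + // which is why 2D tile shapes currently give the best performance.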
+ if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection) { // Optimized version for 2d tile shape [MxK] const auto block_2_tile_map = @@ -169,20 +304,24 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con { // Calculate offsets // Sequence with data to process per block - constexpr auto tile_shape_seq = - generate_sequence_v2([](auto I) { return size(BlockShapeTuple{}.At(I)); }, - Number{}); + constexpr auto projected_tile_shape = + detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{}); + using ProjectedTileShapeTuple = decltype(projected_tile_shape); + constexpr auto projected_tile_shape_seq = + generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); }, + Number{}); // Tuple with number of blocks - const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape); - constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths); + const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection); + const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths); const auto block_idxs = block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id)); - const auto offset_multi_idxs = - CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets()); + const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection); + const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( + projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets()); // Create new layout and tensor const auto tile_layout = - Layout, decltype(aligned_desc)>(tile_shape, - aligned_desc); + Layout, decltype(aligned_desc)>( + projected_tile_shape, aligned_desc); auto tile_tensor = make_tensor(tensor.GetPointer(), tile_layout); // Apply offsets @@ -191,5 +330,61 @@ make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, con } } +/** + * \brief Create local tile for thread block. (At now only packed tile + * is supported). + * + * \note Currently to get the best performance please use 2d shape. + * + * \param tensor Tensor for partition. + * \param tile_shape Shapes of requested tile. + * \param block_id Block index represented as integer. + * \return Tile tensor. + */ +template +__host__ __device__ constexpr auto +make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id) +{ + const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{}); + return make_local_tile(tensor, tile_shape, block_id, projection); +} + +/** + * \brief Pad tensor shapes to be adjusted to tile lengths. + * + * + * \param tensor Tensor to pad. + * \param tile_lengths Tile lengths to align tensor shape. + * \return Padded tensor. 
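+ * + * \note Usage from test_gemm.cpp below: + * auto a_padded = pad(a_global_tensor, shape(a_tile_layout)); + * rounds every tensor extent up to a multiple of the matching tile length, so + * boundary tiles can be addressed without out-of-range indexing.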
+ */ +template +__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths) +{ + const auto& tensor_shape = shape(tensor); + using TensorShapeType = remove_reference_t; + auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); + // Generate sequence with ones to mark that all dims will be padded + constexpr auto do_pads_seq = + generate_sequence_v2([](auto) { return Number<1>{}; }, Number{}); + // Create descriptor with padding + auto padded_desc = + tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq); + // Generate padded shape + const auto padded_shape = generate_tuple( + [&](auto i) { + const auto& dim = size(tensor_shape); + const auto& tile_length = size(tile_lengths); + return ck::math::integer_divide_ceil(dim, tile_length) * tile_length; + }, + Number{}); + // Create layout and tensor + const auto padded_layout = + Layout(padded_shape, padded_desc); + auto partition_tensor = + make_tensor(tensor.GetPointer(), padded_layout); + partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets()); + return partition_tensor; +} + } // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp index 7ec080760a..ee9e438a40 100644 --- a/include/ck/wrapper/utils/tensor_utils.hpp +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" #include "ck/utility/number.hpp" #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" @@ -19,9 +20,9 @@ namespace wrapper { * \brief Memory type, allowed members: * - Generic, * - Global, - * - LDS, - * - SGPR, - * - VGPR, + * - Lds, + * - Sgpr, + * - Vgpr, */ using MemoryTypeEnum = AddressSpaceEnum; @@ -52,12 +53,8 @@ struct Slice __host__ __device__ constexpr auto range(const T& dim) const { if constexpr(is_same_v || is_same_v || - is_same_v) + is_same_v, index_t>) { - if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_))) - { - throw std::runtime_error("Invalid range"); - } if(to_ < 0) { return dim - from_ + to_ + 1; @@ -70,9 +67,10 @@ struct Slice } else { - static_assert(dim >= to_ && from_ >= Number<0>{} && (to_ < 0 || to_ > from_), + static_assert(T{} >= ToType{} && FromType{} >= Number<0>{} && + (ToType{} < 0 || ToType{} > FromType{}), "Invalid range"); - if constexpr(to_ < 0) + if constexpr(ToType{} < 0) { return dim - from_ + to_ + Number<1>{}; } @@ -130,6 +128,23 @@ constexpr auto make_register_tensor(const Layout& return Tensor(layout); } +/** + * \brief Clear tensor. (Only for Vpgr/Sgpr) + * + * \param tensor Tensor to be cleared. + */ +template +__host__ __device__ void +clear(Tensor& tensor) +{ + static_assert( + !Tensor::IsDynamicBuffer); + return tensor.GetBuffer().Clear(); +} + /** * \brief Get Tensor Layout. 
* diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 6c3e29ab87..cadc146795 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -6,3 +6,9 @@ add_gtest_executable(test_copy test_copy.cpp) target_link_libraries(test_copy PRIVATE utility) add_gtest_executable(test_partition test_partition.cpp) target_link_libraries(test_partition PRIVATE utility) +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR + GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR + GPU_TARGETS MATCHES "gfx942") + add_gtest_executable(test_gemm test_gemm.cpp) + target_link_libraries(test_gemm PRIVATE utility) +endif() diff --git a/test/wrapper/test_gemm.cpp b/test/wrapper/test_gemm.cpp new file mode 100644 index 0000000000..b26cd5fed6 --- /dev/null +++ b/test/wrapper/test_gemm.cpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" + +template +void CheckResult(const std::vector& a_data, + const std::vector& b_data, + std::vector& c_m_n_device_result, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + Tensor a_m_k(HostTensorDescriptor({M, K})); + Tensor b_k_n(HostTensorDescriptor({K, N}, {1, K})); + Tensor c_m_n_host_result(HostTensorDescriptor({M, N})); + + a_m_k.mData = a_data; + b_k_n.mData = b_data; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData)); +} + +template +__global__ void DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayoutShape thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + 
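+    // Note: the block tiles above are row-major, e.g. the A tile maps a local
+    // (m, k) index to m * KPerBlock + k, making K the contiguous (vector) dim.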
constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout); + + auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout)); + auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout)); + auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout)); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const ck::index_t block_idx = static_cast(blockIdx.x); + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_padded_global_tensor, + tile_shape, + block_idx, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice); + auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice); + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_padded_global_tensor_k_slice, + tile_shape, + block_idx, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_padded_global_tensor_k_slice, + tile_shape, + block_idx, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); + + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayoutShape& thread_layout) +{ + // Global memory buffers + DeviceMem a_mem(M * K * sizeof(DataType)); + DeviceMem b_mem(K * N * sizeof(DataType)); + DeviceMem c_mem(M * N * sizeof(DataType)); + + std::vector a_data(M * K); + std::vector b_data(K * N); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_data); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_data); + + a_mem.ToDevice(a_data.data()); + b_mem.ToDevice(b_data.data()); + c_mem.SetZero(); + + const ck::index_t grid_size = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) * + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + launch_and_time_kernel(StreamConfig{nullptr}, 
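+                           // Grid: one workgroup per (M, N) output tile; the K
+                           // dimension is looped over inside the kernel itself.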
+ kernel, + dim3(grid_size), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + + std::vector c_data(M * N); + c_mem.FromDevice(c_data.data()); + + CheckResult(a_data, b_data, c_data, M, N, K); +} + +TEST(TestGemm, Float) +{ + using DataType = float; + const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Int8) +{ + using DataType = int8_t; + const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Half) +{ + using DataType = ck::half_t; + const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Float_2x4_4x2_XdlPerWave) +{ + using DataType = float; + const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{}); + const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave); +} diff --git a/test/wrapper/test_partition.cpp b/test/wrapper/test_partition.cpp index cacbfe9d88..8b6d220cd7 100644 --- a/test/wrapper/test_partition.cpp +++ b/test/wrapper/test_partition.cpp @@ -29,17 +29,24 @@ TEST(TestPartition, LocalPartition) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto thread_steps = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); - const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}); + const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{}); + const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}); + // 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim) + const auto thread_projection = + ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{}); + constexpr ck::index_t projection_thread_length = ck::Number<4>{}; - for(ck::index_t thread_id = 0; thread_id < ck::wrapper::size(thread_layout); thread_id++) + for(ck::index_t thread_id = 0; + thread_id < ck::wrapper::size(thread_layout) / projection_thread_length; + thread_id++) { const auto packed_partition = - ck::wrapper::make_local_partition(tensor, thread_layout, thread_id); + ck::wrapper::make_local_partition(tensor, thread_layout, thread_id, thread_projection); const auto expected_partition_size = - ck::wrapper::size(tensor) / ck::wrapper::size(thread_layout); - const auto expected_partition_first_val = thread_id * ck::wrapper::size<0>(thread_steps); + ck::wrapper::size(tensor) / + 
(ck::wrapper::size(thread_layout) / projection_thread_length); + const auto expected_partition_first_val = thread_id * ck::wrapper::size<1>(thread_steps); const auto expected_partition_second_val = expected_partition_first_val + 1; EXPECT_EQ(ck::wrapper::size(packed_partition), expected_partition_size); EXPECT_EQ(packed_partition(0), expected_partition_first_val); @@ -58,8 +65,12 @@ TEST(TestPartition, LocalTile) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - - const auto block_shape = ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}); + // 4d tile partitioning on 3d shape (calculate tile on 4d tile layout, and then skip last dim) + const auto block_shape = + ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{}); + const auto block_projection = + ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2)); + constexpr ck::index_t projection_block_dim = ck::Number<2>{}; const auto num_blocks = ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape), ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape), @@ -69,9 +80,10 @@ TEST(TestPartition, LocalTile) for(auto block_idx : block_idxs) { - const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx); + const auto packed_tile = + ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection); - const auto expected_tile_size = ck::wrapper::size(block_shape); + const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim; auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) * ck::wrapper::size<2>(block_shape) * ck::wrapper::size<2>(strides); From 112b691bb77c6ea6d6fd651c6657b9864c4b6517 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 31 Jan 2024 13:27:17 -0800 Subject: [PATCH 66/75] add new performance tests for mixed fp16/fp8 gemms (#1151) --- script/parse_perf_data.py | 290 --------------------------- script/process_perf_data.py | 6 +- script/profile_mixed_gemm.sh | 52 +++++ script/run_full_performance_tests.sh | 6 + 4 files changed, 63 insertions(+), 291 deletions(-) delete mode 100644 script/parse_perf_data.py create mode 100755 script/profile_mixed_gemm.sh diff --git a/script/parse_perf_data.py b/script/parse_perf_data.py deleted file mode 100644 index 4cb13e6243..0000000000 --- a/script/parse_perf_data.py +++ /dev/null @@ -1,290 +0,0 @@ -#!/usr/bin/env python3 -import os, io, argparse, datetime, re -import numpy as np -import sqlalchemy -from sqlalchemy.types import NVARCHAR, Float, Integer -import pymysql -import pandas as pd -from sshtunnel import SSHTunnelForwarder - -def print_to_string(*args, **kwargs): - output = io.StringIO() - print(*args, file=output, **kwargs) - contents = output.getvalue() - output.close() - return contents - -def parse_args(): - parser = argparse.ArgumentParser(description='Parse results from tf benchmark runs') - parser.add_argument('filename', type=str, help='Log file to prase or directory containing log files') - args = parser.parse_args() - files = [] - if os.path.isdir(args.filename): - all_files = os.listdir(args.filename) - for name in all_files: - if not 'log' in name: - continue - files.append(os.path.join(args.filename, name)) - else: - files = [args.filename] - args.files = files - return args - -def main(): - args = parse_args() - tests = [] - kernels=[] - tflops=[] - dtype=[] - alayout=[] - blayout=[] - M=[] - N=[] - K=[] - 
StrideA=[] - StrideB=[] - StrideC=[] - #parse results, get the Tflops value for "Best Perf" kernels - - glue="" - for filename in args.files: - for line in open(filename): - if 'Branch name' in line: - lst=line.split() - branch_name=lst[2] - if 'On branch' in line: - lst=line.split() - branch_name=lst[2] - if 'Node name' in line: - lst=line.split() - node_id=lst[2] - if 'GPU_arch' in line: - lst=line.split() - gpu_arch=lst[2] - if 'HIP version' in line: - lst=line.split() - hip_vers=lst[2] - if 'Compute Unit' in line: - lst=line.split() - compute_units=lst[2] - if 'InstalledDir' in line: - lst=line.split() - rocm_vers=lst[1][lst[1].find('/opt/rocm-')+len('/opt/rocm-'):lst[1].rfind('/llvm/bin')] - print("Branch name:",branch_name) - print("Node name:",node_id) - print("GPU_arch:",gpu_arch) - print("Compute units:",compute_units) - print("ROCM_version:",rocm_vers) - print("HIP_version:",hip_vers) - - - #parse gemm performance tests: - if 'gemm' in filename: - for filename in args.files: - for line in open(filename): - if 'Best Perf' in line: - lst=line.split() - if len(lst)>=37: #the line is complete - tests.append(glue.join(lst[5:30])) - kernels.append(glue.join(lst[37:])) - tflops.append(lst[33]) - dtype.append(lst[5]) - alayout.append(lst[8]) - blayout.append(lst[11]) - M.append(lst[14]) - N.append(lst[17]) - K.append(lst[20]) - StrideA.append(lst[23]) - StrideB.append(lst[26]) - StrideC.append(lst[29]) - elif len(lst)<37 and len(lst)>=33: #the tflops are available - tests.append(glue.join(lst[5:30])) - kernels.append("N/A") - tflops.append(lst[33]) - dtype.append(lst[5]) - alayout.append(lst[8]) - blayout.append(lst[11]) - M.append(lst[14]) - N.append(lst[17]) - K.append(lst[20]) - StrideA.append(lst[23]) - StrideB.append(lst[26]) - StrideC.append(lst[29]) - print("warning: incomplete line:",lst) - elif len(lst)<33: #even the tflops are not available - print("Error in ckProfiler output!") - print("warning: incomplete line=",lst) - #sort results - #sorted_tests = sorted(tests) - #print("sorted tests:",sorted_tests) - sorted_tflops = [x for _,x in sorted(zip(tests,tflops))] - #sorted_kernels = [x for _,x in sorted(zip(tests,kernels))] - test_list=list(range(1,len(tests)+1)) - - #parse resnet50 performance tests: - if 'resnet50' in filename: - for filename in args.files: - for line in open(filename): - if 'Best Perf' in line: - lst=line.split() - tflops.append(lst[4]) - - print("Number of tests:",len(tflops)) - sql_hostname = '127.0.0.1' - sql_username = os.environ["dbuser"] - sql_password = os.environ["dbpassword"] - sql_main_database = 'miopen_perf' - sql_port = 3306 - ssh_host = os.environ["dbsship"] - ssh_user = os.environ["dbsshuser"] - ssh_port = int(os.environ["dbsshport"]) - ssh_pass = os.environ["dbsshpassword"] - - with SSHTunnelForwarder( - (ssh_host, ssh_port), - ssh_username=ssh_user, - ssh_password=ssh_pass, - remote_bind_address=(sql_hostname, sql_port)) as tunnel: - - sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'. 
- format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database)) - conn = sqlEngine.connect() - - #save gemm performance tests: - if 'gemm' in filename: - - #write the ck_gemm_test_params table - #only needed once the test set changes - ''' - sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))] - sorted_alayout = [x for _,x in sorted(zip(tests,alayout))] - sorted_blayout = [x for _,x in sorted(zip(tests,blayout))] - sorted_M = [x for _,x in sorted(zip(tests,M))] - sorted_N = [x for _,x in sorted(zip(tests,N))] - sorted_K = [x for _,x in sorted(zip(tests,K))] - sorted_StrideA = [x for _,x in sorted(zip(tests,StrideA))] - sorted_StrideB = [x for _,x in sorted(zip(tests,StrideB))] - sorted_StrideC = [x for _,x in sorted(zip(tests,StrideC))] - ck_gemm_params=[test_list,sorted_dtypes,sorted_alayout,sorted_blayout, - sorted_M,sorted_N,sorted_K,sorted_StrideA,sorted_StrideB, - sorted_StrideC] - df=pd.DataFrame(np.transpose(ck_gemm_params),columns=['Test_number','Data_type', - 'Alayout','BLayout','M','N','K', 'StrideA','StrideB','StrideC']) - print(df) - - dtypes = { - 'Test_number': Integer(), - 'Data_type': NVARCHAR(length=5), - 'Alayout': NVARCHAR(length=12), - 'Blayout': NVARCHAR(length=12), - 'M': Integer(), - 'N': Integer(), - 'K': Integer(), - 'StrideA': Integer(), - 'StrideB': Integer(), - 'StrideC': Integer() - } - df.to_sql("ck_gemm_test_params",conn,if_exists='replace',index=False, dtype=dtypes) - ''' - - #read baseline results for the latest develop branch - query = '''SELECT * from ck_gemm_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_gemm_tflops where Branch_ID='develop' );''' - tflops_base = pd.read_sql_query(query, conn) - - #write new results to the db - testlist=[] - for i in range(1,len(tests)+1): - testlist.append("Test%i"%i) - ck_gemm_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] - flops=pd.DataFrame(data=[ck_gemm_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) - df_add=pd.DataFrame(data=[sorted_tflops],columns=testlist) - flops=pd.concat([flops,df_add],axis=1) - print("new tflops for gemm tests:",flops) - flops.to_sql("ck_gemm_tflops",conn,if_exists='append',index=False) - - #save resnet50 performance tests: - if 'resnet50' in filename: - #read baseline results for the latest develop branch - query = '''SELECT * from ck_resnet50_N256_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N256_tflops where Branch_ID='develop' );''' - tflops_base_N256 = pd.read_sql_query(query, conn) - query = '''SELECT * from ck_resnet50_N4_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N4_tflops where Branch_ID='develop' );''' - tflops_base_N4 = pd.read_sql_query(query, conn) - - #write new results to the db - testlist=[] - for i in range(1,50): - testlist.append("Layer%i"%i) - ck_resnet_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] - flops0=pd.DataFrame(data=[ck_resnet_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) - df_add=pd.DataFrame(data=[tflops[0:49]],columns=testlist) - flops=pd.concat([flops0,df_add],axis=1) - print("new tflops for N=256 resnet50 test:",flops) - flops.to_sql("ck_resnet50_N256_tflops",conn,if_exists='append',index=False) - df_add=pd.DataFrame(data=[tflops[49:98]],columns=testlist) - flops=pd.concat([flops0,df_add],axis=1) - 
print("new tflops for N=4 resnet50 test:",flops) - flops.to_sql("ck_resnet50_N4_tflops",conn,if_exists='append',index=False) - - conn.close() - - #compare the results to the baseline if baseline exists - regression=0 - if 'gemm' in filename: - if not tflops_base.empty: - base=tflops_base[testlist].to_numpy(dtype='float') - base_list=base[0] - ave_perf=0 - for i in range(len(base_list)): - # success criterion: - if base_list[i]>1.01*float(sorted_tflops[i]): - print("test # ",i,"shows regression by {:.3f}%".format( - (float(sorted_tflops[i])-base_list[i])/base_list[i]*100)) - regression=1 - ave_perf=ave_perf+float(sorted_tflops[i])/base_list[i] - if regression==0: - print("no regressions found") - ave_perf=ave_perf/len(base_list) - print("average performance relative to baseline:",ave_perf) - else: - print("could not find a baseline") - if 'resnet50' in filename: - if not tflops_base_N256.empty: - base=tflops_base_N256[testlist].to_numpy(dtype='float') - base_list=base[0] - ave_perf=0 - for i in range(len(base_list)): - # success criterion: - if base_list[i]>1.01*float(tflops[i]): - print("layer # ",i,"shows regression by {:.3f}%".format( - (float(tflops[i])-base_list[i])/base_list[i]*100)) - regression=1 - ave_perf=ave_perf+float(tflops[i])/base_list[i] - if regression==0: - print("no regressions found") - ave_perf=ave_perf/len(base_list) - print("average performance relative to baseline:",ave_perf) - else: - print("could not find a baseline for N=256") - if not tflops_base_N4.empty: - base=tflops_base_N4[testlist].to_numpy(dtype='float') - base_list=base[0] - ave_perf=0 - for i in range(len(base_list)): - # success criterion: - if base_list[i]>1.01*float(tflops[i+49]): - print("layer # ",i,"shows regression by {:.3f}%".format( - (float(tflops[i+49])-base_list[i])/base_list[i]*100)) - regression=1 - ave_perf=ave_perf+float(tflops[i+49])/base_list[i] - if regression==0: - print("no regressions found") - ave_perf=ave_perf/len(base_list) - print("average performance relative to baseline:",ave_perf) - else: - print("could not find a baseline for N=4") - - #return 0 if performance criteria met, otherwise return 1 - return regression - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/script/process_perf_data.py b/script/process_perf_data.py index e8b8e1458c..d7e40569fd 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -133,7 +133,7 @@ def parse_logfile(logfile): if 'Best Perf' in line: lst=line.split() res.append(lst[4]) - elif 'onnx_gemm' in logfile or 'splitK_gemm' in logfile: + elif 'onnx_gemm' in logfile or 'splitK_gemm' in logfile or 'mixed_gemm' in logfile: for line in open(logfile): if 'Best Perf' in line: lst=line.split() @@ -295,6 +295,10 @@ def main(): for i in range(1,len(results)+1): testlist.append("Test%i"%i) table_name="ck_splitK_gemm_tflops" + if 'mixed_gemm' in filename: + for i in range(1,len(results)+1): + testlist.append("Test%i"%i) + table_name="ck_mixed_gemm_tflops" tflops_base = get_baseline(table_name,conn) store_new_test_result(table_name, results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, conn) diff --git a/script/profile_mixed_gemm.sh b/script/profile_mixed_gemm.sh new file mode 100755 index 0000000000..383c7ea36e --- /dev/null +++ b/script/profile_mixed_gemm.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +LAYOUT=$3 +VERIFY=$4 +INIT=$5 +LOG=$6 +TIME=$7 +KBatch=$8 
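+# The grid below sweeps M and N over {16, 2048, 8192} and K over {1024, 8192, 65536};
+# strides of -1 let ckProfiler pick its default (packed) strides for the chosen layout.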
+ +######## op datatype layout verify init log time M___ N___ K___ StrideA StrideB StrideC KBatch_ + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 16 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 16 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 16 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 2048 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 2048 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 2048 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 8192 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 8192 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 8192 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 16 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 16 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 16 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 2048 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 8192 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 8192 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 2048 8192 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 16 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 16 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 16 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 2048 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 2048 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 2048 65536 -1 -1 -1 $KBatch + + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 1024 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 8192 -1 -1 -1 $KBatch + $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8192 8192 65536 -1 -1 -1 $KBatch + \ No newline at end of file diff --git a/script/run_full_performance_tests.sh b/script/run_full_performance_tests.sh index eae334ae2d..90678389fa 100755 --- a/script/run_full_performance_tests.sh +++ b/script/run_full_performance_tests.sh @@ -147,3 +147,9 @@ export onnx_log="perf_onnx_gemm.log" print_log_header $onnx_log $env_type $branch $host_name ./profile_onnx_gemm.sh gemm 0 0 $verify 1 0 1 2>&1 | tee -a $onnx_log ./profile_onnx_gemm.sh gemm 1 0 $verify 1 0 1 2>&1 | tee -a $onnx_log + +#run mixed fp16/fp8 and fp8/fp16 gemm tests +export mixed_gemm_log="perf_mixed_gemm.log" +print_log_header $mixed_gemm_log $env_type $branch $host_name +./profile_mixed_gemm.sh gemm_splitk 4 0 $verify 2 0 1 16 2>&1 | tee -a $mixed_gemm_log +./profile_mixed_gemm.sh gemm_splitk 5 0 $verify 2 0 1 16 2>&1 | tee -a $mixed_gemm_log \ No newline at end of file From 171ca260b506b32e53c899bdc580accb3469937c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 2 Feb 2024 20:25:54 +0100 Subject: [PATCH 67/75] Extend gemm traits number for ck 
wrapper (#1153) --- .../traits/blockwise_gemm_xdl_traits.hpp | 21 +++++++++++++++++++ test/wrapper/test_gemm.cpp | 8 +++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp index 24d863f5b1..8301636a9f 100644 --- a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp +++ b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp @@ -34,6 +34,7 @@ struct BlockwisGemmXdlTraits static constexpr index_t K1 = K1Value; }; +// K1 = 4 struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> { }; @@ -43,6 +44,26 @@ struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits< struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> { }; +// K1 = 8 +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8> +{ +}; +// K1 = 16 +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16> +{ +}; +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16> +{ +}; } // namespace wrapper } // namespace ck diff --git a/test/wrapper/test_gemm.cpp b/test/wrapper/test_gemm.cpp index b26cd5fed6..12245490d1 100644 --- a/test/wrapper/test_gemm.cpp +++ b/test/wrapper/test_gemm.cpp @@ -225,10 +225,10 @@ TEST(TestGemm, Int8) using DataType = int8_t; const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}); const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( + PerformGemm( 512, 512, 128, tile_shape, thread_layout); // Irregular case - PerformGemm( + PerformGemm( 129, 129, 67, tile_shape, thread_layout); } @@ -237,10 +237,10 @@ TEST(TestGemm, Half) using DataType = ck::half_t; const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( + PerformGemm( 512, 512, 128, tile_shape, thread_layout); // Irregular case - PerformGemm( + PerformGemm( 129, 129, 67, tile_shape, thread_layout); } From 180f16f9acf1455acbdf45859963909de9f6169c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 2 Feb 2024 11:35:26 -0800 Subject: [PATCH 68/75] Add support for more Navi2x and Navi3x models. 
(#1152) * add support for navi2x and navi3x models * fix syntax * use common macro for different mi300 architectures --- include/ck/ck.hpp | 43 ++++++++++++------- include/ck/host_utility/device_prop.hpp | 19 ++++++++ ...d_contraction_multiple_d_wmma_cshuffle.hpp | 3 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 2 +- .../device_batched_gemm_e_permute_xdl.hpp | 2 +- .../device_batched_gemm_gemm_xdl_cshuffle.hpp | 2 +- .../impl/device_batched_gemm_multi_d_xdl.hpp | 2 +- .../device_batched_gemm_multiple_d_dl.hpp | 12 ++---- ...ultiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 2 +- ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 2 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 2 +- ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 2 +- .../device/impl/device_batched_gemm_xdl.hpp | 2 +- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 2 +- ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 2 +- ...evice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 2 +- .../device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp | 5 +-- .../gpu/device/impl/device_gemm_dl.hpp | 5 +-- .../gpu/device/impl/device_gemm_dpp.hpp | 3 +- .../device_gemm_multiple_abd_xdl_cshuffle.hpp | 2 +- .../device/impl/device_gemm_multiple_d_dl.hpp | 12 ++---- ...gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 2 +- ...emm_multiple_d_multiple_r_xdl_cshuffle.hpp | 2 +- .../device_gemm_multiple_d_wmma_cshuffle.hpp | 3 +- .../device_gemm_multiple_d_xdl_cshuffle.hpp | 2 +- .../gpu/device/impl/device_gemm_wmma.hpp | 3 +- .../gpu/device/impl/device_gemm_xdl.hpp | 3 +- .../device/impl/device_gemm_xdl_streamk.hpp | 4 +- .../device_gemm_xdl_waveletmodel_cshuffle.hpp | 2 +- ...ed_contraction_multiple_d_xdl_cshuffle.hpp | 2 +- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 3 +- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 2 +- .../device_grouped_conv_bwd_weight_dl.hpp | 5 +-- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 3 +- ...e_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 2 +- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 12 ++---- ...ice_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp | 9 ++-- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 5 +-- ...fwd_multiple_d_multiple_r_xdl_cshuffle.hpp | 5 +-- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 3 +- .../device_grouped_gemm_multiple_d_dl.hpp | 37 ++++++++-------- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 2 +- .../device/impl/device_grouped_gemm_xdl.hpp | 2 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 2 +- ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 2 +- ...tk_contraction_multiple_d_xdl_cshuffle.hpp | 2 +- ...e_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp | 2 +- .../gpu/grid/gridwise_gemm_dpp.hpp | 3 +- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 11 ++--- ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 3 +- .../gridwise_gemm_reduce_xdl_cshuffle_v1.hpp | 2 +- .../gpu/grid/gridwise_gemm_wmma.hpp | 5 +-- .../grid/gridwise_gemm_xdl_cshuffle_v1.hpp | 4 +- .../grid/gridwise_gemm_xdl_cshuffle_v2.hpp | 4 +- ...ridwise_gemm_xdl_layernorm_cshuffle_v1.hpp | 2 +- .../grid/gridwise_gemm_xdlops_bwd_weight.hpp | 2 +- .../gridwise_gemm_xdlops_skip_b_lds_v1.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_streamk.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 4 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v3r1.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v3r2.hpp | 2 +- .../gpu/grid/gridwise_gemm_xdlops_v3r3.hpp | 2 +- .../gpu/grid/gridwise_tensor_rearrange.hpp | 5 +-- include/ck/utility/amd_wmma.hpp | 23 +++++----- 
include/ck/utility/amd_xdlops.hpp | 22 ++++++---- include/ck/utility/type_convert.hpp | 30 +++++++------ .../test_grouped_convnd_bwd_weight.cpp | 5 +-- 69 files changed, 194 insertions(+), 194 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 88efb0277b..c93d1d0639 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -44,16 +44,30 @@ #define CK_USE_WAVES_PER_EU 0 #endif +// define general macros for various architectures +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif +#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) +#define __gfx101__ +#endif +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) +#define __gfx103__ +#endif +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#define __gfx11__ +#endif + // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code + defined(__gfx90a__) || defined(__gfx94__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 -#elif defined(__gfx1030__) // for GPU code +#elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#elif defined(__gfx11__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -61,12 +75,12 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \ + defined(__gfx94__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#elif defined(__gfx11__) #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 @@ -75,23 +89,22 @@ // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code #define CK_USE_AMD_MFMA #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(defined(__gfx90a__) || defined(__gfx94__)) #define CK_USE_AMD_MFMA_BF16_1K_OP #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) #define CK_USE_AMD_MFMA_GFX940 #endif // WMMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_WMMA -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#elif defined(__gfx11__) // for GPU code #define CK_USE_AMD_WMMA #endif @@ -107,15 +120,13 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define 
CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index e8dabc9973..13e5268752 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -65,4 +65,23 @@ inline bool is_lds_direct_load_supported() ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; } +inline bool is_navi1_supported() +{ + return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || + ck::get_device_name() == "gfx1012"; +} + +inline bool is_navi2_supported() +{ + return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" || + ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" || + ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; +} + +inline bool is_navi3_supported() +{ + return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; +} + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index 4d599e8017..b32f3a8daa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -770,8 +770,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::is_navi3_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 32c45bc57e..64aa398d53 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index ba22cf0bf8..d06eab1264 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -75,7 +75,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 3dbe8c6722..e950169ccf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -61,7 +61,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 545d7e576f..d6b92bc97a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -84,7 +84,7 @@ __global__ void { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp index b51c600476..b01e029c03 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -70,9 +70,8 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -648,11 +647,8 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD(p_as_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 290abe221a..1f65afed3d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -54,7 +54,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index a8e586b20c..55cf8df272 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -56,7 +56,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 3178f73f4b..d95671be7e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1393,9 +1393,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl static bool IsSupportedArgument(const Argument& arg) { // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) + if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || + ck::is_navi3_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index 514aa4452e..bac124a2f1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -536,9 +536,8 @@ struct DeviceGemmDl : public DeviceGemm(p_as_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index ad51096db7..8490fa52fd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -50,9 +50,8 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__) || 
defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); @@ -552,11 +551,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 916f29a904..bb2db930c8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -61,7 +61,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index 44b3518e2c..fd90c7f1ea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -484,8 +484,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index d98725cf9d..42f8daef71 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -53,7 +53,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index f64450b75f..98d14caa6d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -411,8 +411,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index b008a6409e..5188ece333 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -184,8 +184,7 @@ struct DeviceGemmXdl : public DeviceGemm || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp index c8799e5154..51b8958d61 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp @@ -243,9 +243,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 4823f5d489..cc022b89c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -38,7 +38,7 @@ __global__ void const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index d66363c45c..0b3de153c3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -627,8 +627,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::is_navi3_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index a157d16181..c0fa9ad882 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -87,7 +87,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index a5f34f0b24..534467b959 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -48,9 +48,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) 
|| defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index dd591fb781..8850b13d0a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -698,8 +698,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { // check device - if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || - get_device_name() == "gfx1102") + if(ck::is_navi3_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 468c92348e..26b0eae915 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -55,7 +55,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 6c8c8c2954..c3023301f3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -90,9 +90,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -666,11 +665,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK namespace ctc = tensor_layout::convolution; // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" || 
- ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || - ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || - ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")) + if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || + ck::is_navi2_supported() || ck::is_navi3_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index f18fbcfe4b..d731e5ddac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -106,8 +106,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx11__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -601,9 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index 6665be7944..ab1c4fc08f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -156,7 +156,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -813,8 +813,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || - get_device_name() == "gfx941" || get_device_name() == "gfx942") + else if(ck::is_lds_direct_load_supported()) { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 0050a5b281..ba2a4b0f7a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -531,8 +531,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle namespace ctc = tensor_layout::convolution; // check device - if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::is_navi3_supported()) { if constexpr(!(is_same_v 
|| is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 0190b3cee6..6f7d7c3894 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -39,9 +39,8 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -668,26 +667,24 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp index 6eca77c89c..b473d7cbf2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -28,8 +28,7 @@ __global__ void #endif kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 53b2169bc6..f514e3a119 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -54,8 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -148,8 +147,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; @@ -244,8 +242,7 @@ __global__ void const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) __shared__ char 
p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; GridwiseOp::template Run(p_a_grid, @@ -274,7 +271,7 @@ __global__ void ignore = b_element_op; ignore = cde_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__ )) +#endif // end of if (defined(__gfx11__ )) } template < // DataType Family diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index b6a17e53a1..4cee1ed34b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -55,8 +55,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \ - defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index d75b631e61..0e5777e561 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -55,7 +55,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index d8b31311b1..066cfc62f2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -49,8 +49,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -75,7 +74,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx11__)) } template ( @@ -50,7 +50,7 @@ __global__ void typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index 2ad2dd9915..db9625c6e6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -26,7 +26,7 @@ __global__ void kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -54,7 +54,7 @@ __global__ void typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 013120c540..7f815de1f9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -58,7 +58,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 06c87d1892..5617f67f8b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -167,7 +167,7 @@ __global__ void const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index b12bcee0f4..7c401a4957 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -45,7 +45,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index 2b1814c03b..e9190dee29 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -38,7 +38,7 @@ 
__global__ void typename GridwiseGemm::Block2CTileMap block_mapping) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 0b50601648..4f3caff248 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -39,7 +39,7 @@ __global__ void const CGridDesc_M_N c_grid_desc_m_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -70,7 +70,7 @@ __global__ void kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const auto a_grid_desc_k0_m_k1 = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 19fbee727f..7d8e94c001 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -43,7 +43,7 @@ __global__ void const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 87e1e0e8d9..6cbb834395 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -37,7 +37,7 @@ __global__ void const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index b766b70a67..15c64f2e47 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -47,7 +47,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff 
--git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index fbe4dd409d..e22bfb6439 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -50,7 +50,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 1dc8d31efe..3da5e66018 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -54,7 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index 9535ca69a9..6772524e0a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -35,9 +35,8 @@ __global__ void const Block2ETileMap block_2_tile_map, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) GridwiseTensorRearrangeKernel::Run(in_grid_desc, p_in_global, out_grid_desc, diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index dd7f0b770a..1bb0140f3e 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -9,6 +9,9 @@ // TODO: Add arch limitation namespace ck { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#define __gfx11__ +#endif /********************************WAVE32 MODE***********************************************/ // src: fp16, dst: fp32 @@ -25,7 +28,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> // delete them. 
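// [editor's note] The three added lines above introduce the __gfx11__ umbrella
// macro (also picking up gfx1103) that the rest of this patch substitutes for
// repeated defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
// chains; amd_xdlops.hpp and type_convert.hpp below do the same with __gfx94__
// for the MI300 (gfx940/941/942) family, and host code swaps get_device_name()
// string comparisons for helpers such as ck::is_navi3_supported(). A minimal
// sketch of a guard site after the change (illustrative kernel, not from the
// patch):
#if 0
__global__ void some_wmma_guarded_kernel(float* out)
{
#if defined(__gfx11__) // one umbrella check instead of three per-model checks
    out[threadIdx.x] = 1.0f;
#endif
}
#endif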
// amd_assembly_wmma_f32_16x16x16_f16_w32( // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else @@ -46,7 +49,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); @@ -71,7 +74,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else @@ -95,7 +98,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); @@ -117,7 +120,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( neg_a, @@ -145,7 +148,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else @@ -166,7 +169,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); @@ -191,7 +194,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else @@ -215,7 +218,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); @@ -237,7 +240,7 @@ struct 
intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( neg_a, diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index afc066405e..0ee52b9570 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -4,6 +4,10 @@ #pragma once namespace ck { +// Define the common macro for MI300 models +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif // fp32 template @@ -341,7 +345,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> template __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -361,7 +365,7 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> template __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(reg_a), @@ -393,7 +397,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> template __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(reg_a), bit_cast(reg_b), @@ -424,7 +428,7 @@ struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> template __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( bit_cast(reg_a), @@ -456,7 +460,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> template __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( bit_cast(reg_a), bit_cast(reg_b), @@ -487,7 +491,7 @@ struct intrin_mfma_f32_32x32x16f8bf8<32, 32> template __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( bit_cast(reg_a), @@ -519,7 +523,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> template __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( bit_cast(reg_a), bit_cast(reg_b), @@ -550,7 +554,7 @@ struct intrin_mfma_f32_32x32x16bf8f8<32, 32> template __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if 
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( bit_cast(reg_a), @@ -582,7 +586,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> template __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( bit_cast(reg_a), bit_cast(reg_b), diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 11db866152..6bbff98312 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -8,6 +8,10 @@ #include "ck/utility/random_gen.hpp" namespace ck { +// Define the common macro for MI300 models +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif // Convert X to Y, both X and Y are non-const data types. template (float x) { constexpr int seed = 42; uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -133,7 +137,7 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) template <> inline __host__ __device__ f8_t f8_convert_sr(half_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // convert to float and use native converion return f8_convert_sr(type_convert(x)); #else @@ -154,7 +158,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(float x) { constexpr int seed = 42; uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -180,7 +184,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(float x) template <> inline __host__ __device__ bf8_t f8_convert_sr(half_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // convert to float and use native converion return f8_convert_sr(type_convert(x)); #else @@ -203,7 +207,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); template <> inline __host__ __device__ f8_t f8_convert_rne(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? 
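/* [editor's note] 240.0f here is the largest finite value of the fp8 format
   implemented by gfx94x hardware (e4m3 with the "fnuz" encoding), so inputs
   are saturated to [-240, 240] before the native conversion that follows. */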
-max_fp8 : x); union @@ -232,7 +236,7 @@ inline __host__ __device__ f8_t f8_convert_rne(float x) template <> inline __host__ __device__ f8_t f8_convert_rne(half_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // convert to float and use native converion return f8_convert_rne(type_convert(x)); #else @@ -250,7 +254,7 @@ inline __host__ __device__ f8_t f8_convert_rne(half_t x) template <> inline __host__ __device__ bf8_t f8_convert_rne(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -277,7 +281,7 @@ inline __host__ __device__ bf8_t f8_convert_rne(float x) template <> inline __host__ __device__ bf8_t f8_convert_rne(half_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // convert to float and use native converion return f8_convert_rne(type_convert(x)); #else @@ -306,7 +310,7 @@ inline __host__ __device__ f8_t type_convert(float x) template <> inline __host__ __device__ float type_convert(f8_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); @@ -321,7 +325,7 @@ inline __host__ __device__ float type_convert(f8_t x) template <> inline __host__ __device__ float2_t type_convert(f8x2_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) const auto i16val = bit_cast(x); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); #else @@ -363,7 +367,7 @@ inline __host__ __device__ f8_t type_convert(half_t x) template <> inline __host__ __device__ half_t type_convert(f8_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); #else @@ -387,7 +391,7 @@ inline __host__ __device__ bf8_t type_convert(float x) template <> inline __host__ __device__ float type_convert(bf8_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); @@ -414,7 +418,7 @@ inline __host__ __device__ bf8_t type_convert(half_t x) template <> inline __host__ __device__ half_t type_convert(bf8_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // use native conversion to float and convert to fp16 return type_convert(type_convert(x)); #else diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 856f9fd15c..98e66c8a36 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -55,10 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test } } - const bool is_navi3x = ck::get_device_name() == "gfx1100" || - ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102"; - if(is_navi3x) + if(ck::is_navi3_supported()) { // on navi3x only support for 3d is implemented if constexpr(NDimSpatial{} != 3) From f0dd1da088d060a0cc51f8b580073bbec1dc60fd Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:34:47 -0800 Subject: [PATCH 69/75] Delete any dangling images after building a new one. 
(#1155) * delete dangling docker images * fix groovy syntax * fix groovy syntax again * try a different way to delete dangling images --- Jenkinsfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Jenkinsfile b/Jenkinsfile index 80e7b044f1..071ac31439 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -135,6 +135,7 @@ def buildDocker(install_prefix){ echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs + ' .') retimage.push() + sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' } else{ echo "Checking for image: ${image_name}" From 62996211076a70e7c2fecd3ce0a0ef2f49201236 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 6 Feb 2024 09:24:34 -0800 Subject: [PATCH 70/75] Bump rocm-docs-core from 0.33.0 to 0.33.1 in /docs/sphinx (#1158) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.33.0 to 0.33.1. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.33.0...v0.33.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 88142aa373..c80177bd30 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.33.0 +rocm-docs-core==0.33.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 12414c7470..a36f5e2be8 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.33.0 +rocm-docs-core==0.33.1 # via -r requirements.in six==1.16.0 # via From 6951858221e03e321cb38c55b3a5ef3c68b5b79d Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Wed, 7 Feb 2024 01:08:34 +0100 Subject: [PATCH 71/75] Implement direct loads split-K GEMM kernel (#1137) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * WIP: Implement direct loads split-K GEMM kernel * Clean the review --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Bartłomiej Kocot --- example/35_splitK_gemm/CMakeLists.txt | 3 + .../splitK_gemm_xdl_lds_direct_load_fp16.cpp | 82 ++ ...m_xdl_splitk_c_shuffle_lds_direct_load.hpp | 423 ++++++++ ...ultiple_d_xdl_cshuffle_lds_direct_load.hpp | 37 +- ...ise_gemm_xdlops_splitk_lds_direct_load.hpp | 962 ++++++++++++++++++ include/ck/utility/amd_lds.hpp | 43 + .../gpu/gemm_splitk.hpp | 8 +- .../gpu/gemm_splitk/CMakeLists.txt | 1 + ...ect_load_f16_f16_f16_mk_nk_mn_instance.cpp | 79 ++ 9 files changed, 1614 insertions(+), 24 deletions(-) create mode 100644 example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp create mode 100644 include/ck/utility/amd_lds.hpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index eff6b6f3fa..f98308d687 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) + add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp new file mode 100644 index 0000000000..97a3f89e5e --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#define DIRECT_LOAD 1 + +#if DIRECT_LOAD +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#endif + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +#if DIRECT_LOAD + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle_LdsDirectLoad + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>; +// clang-format on + +#else + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#endif + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp new file mode 100644 index 0000000000..dd33e577bf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -0,0 +1,423 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
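// [editor's note] Split-K background for this new device op: each of the
// k_batch grid slices computes a partial GEMM over its own chunk of the K
// dimension and accumulates into C, which is why the invoker below zeroes the
// C buffer with hipMemsetAsync and launches the AtomicAdd kernel variant when
// k_batch > 1. A scalar sketch of the decomposition (illustrative only; names
// and indexing are hypothetical, assuming row-major A, column-major B with
// StrideA = StrideB = K and K % k_batch == 0):
#if 0
for(int kb = 0; kb < k_batch; ++kb) // one grid z-slice per K-chunk on the GPU
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float partial = 0.f;
            for(int k = kb * (K / k_batch); k < (kb + 1) * (K / k_batch); ++k)
                partial += A[m * K + k] * B[n * K + k];
            C[m * N + n] += partial; // done with an atomic add on the device
        }
#endif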
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + +struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_xdlops_splitk_lds_direct_load< + BlockSize, + ADataType, + BDataType, + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeType>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Invoker + struct Invoker : public BaseInvoker + { + + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + + const auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0Padded = karg.K0Padded; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(kbatch > 1) + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + static_cast(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + 
GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 4cee1ed34b..cd36b9e51a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
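// [editor's note] The hunks below drop this file's private AllocateBlockBuffers
// helper in favour of the shared ck::lds_utils::AllocateLdsBuffers utility that
// this patch adds in ck/utility/amd_lds.hpp, so the new split-K kernel can
// reuse the same logic. A simplified sketch of what such a helper does
// (hypothetical names, raw pointers instead of ck dynamic buffers):
#if 0
template <typename T, int NumBuffers>
__device__ void allocate_lds_buffers(
    void* p_shared, int num_elems, int offset_elems, int align, T* (&bufs)[NumBuffers])
{
    // Pad each buffer up to the LDS alignment so consecutive prefetch buffers
    // start on aligned boundaries, then carve them out of one __shared__ arena.
    const int stride = (num_elems + align - 1) / align * align;
    for(int i = 0; i < NumBuffers; ++i)
        bufs[i] = static_cast<T*>(p_shared) + offset_elems + i * stride;
}
#endif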
#pragma once +#include "ck/utility/amd_lds.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -491,22 +492,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } - template - __device__ static auto AllocateBlockBuffers(void* p_shared, - int32_t num_elems, - int32_t offset_elems, - int32_t max_lds_align) - { - const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align); - return generate_tuple( - [&](auto i) { - const int32_t local_offset = i * single_buffer_offset; - return make_dynamic_buffer( - static_cast(p_shared) + local_offset + offset_elems, num_elems); - }, - Number{}); - } - template ( - p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align); + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage; auto b_block_buffers = - AllocateBlockBuffers(p_shared, - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), - b_buffers_offset, - max_lds_align); + ck::lds_utils::AllocateLdsBuffers( + p_shared, + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp new file mode 100644 index 0000000000..94306a4c95 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -0,0 +1,962 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
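// [editor's note] How this new gridwise kernel pads the K dimension (see
// CalculateK0Padded / CalculateKPadded further down): K is decomposed as
// K -> KBatch * K0Padded * K1, with K0 rounded up to a multiple of K0PerBlock.
// A worked example under assumed tile parameters K0PerBlock = 4, K1 = 4:
//   K = 1000, KBatch = 4
//   K_t      = KBatch * K0PerBlock * K1      = 4 * 4 * 4  = 64
//   K0Padded = ceil(K / K_t) * K0PerBlock    = 16 * 4     = 64
//   KPadded  = KBatch * K0Padded * K1        = 4 * 64 * 4 = 1024 >= K
// so each of the 4 split-K slices covers K0Padded * K1 = 256 K-elements, and
// the tail beyond K = 1000 is masked off by the right-pad transforms below.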
+ +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); +#else + ignore = karg; + ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdlops_splitk_lds_direct_load +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0Padded; + index_t k_batch; + + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + 
StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + K0Padded(K0Padded_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0Padded:" << K0Padded << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, 
Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), 
max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max(NumGemmKPrefetchStage * (a_block_space_size + b_block_space_size) * + sizeof(ComputeType), + c_block_size * sizeof(FloatC)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + + const auto num_k_loop = karg.K0Padded / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; + return KPad; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) + { + const index_t num_loop = K0Padded / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + 
make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; + + template + __device__ static void Run(const Argument& karg, + void* __restrict__ p_shared_block, + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [KBatch, M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + 
make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_K0_M_K1, + FloatA, + ComputeType, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_K0_N_K1, + FloatB, + ComputeType, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, // ComputeType A + ComputeType, // ComputeType B + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + 
a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + 
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_lds.hpp b/include/ck/utility/amd_lds.hpp new file mode 100644 index 0000000000..c218fded96 --- /dev/null +++ b/include/ck/utility/amd_lds.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/dynamic_buffer.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +namespace lds_utils { + +/** \brief Allocate a given number of buffers in LDS and return them as a tuple. + * + * \tparam DataType Data type of elements to be stored in LDS. + * \tparam NumBuffers Number of buffers to be allocated. + * \param lds_ptr Address of the beginning of LDS space. + * \param num_elems_per_buffer Number of elements to allocate per single buffer. + * \param start_offset_elems Number of elements to move from the start of LDS for the allocation of + * the first buffer. \param lds_alignment Alignment of every buffer allocation given as a number of + * elements. \return Tuple of dynamic buffers representing memory allocated in LDS. 
+ */ +template +__device__ static auto AllocateLdsBuffers(void* lds_ptr, + int32_t num_elems_per_buffer, + int32_t start_offset_elems, + int32_t lds_alignment) +{ + const DataType* lds_start = static_cast(lds_ptr) + start_offset_elems; + const int32_t single_buffer_offset = + math::integer_least_multiple(num_elems_per_buffer, lds_alignment); + return generate_tuple( + [&](auto i) { + const int32_t local_offset = i * single_buffer_offset; + return make_dynamic_buffer(lds_start + local_offset, + num_elems_per_buffer); + }, + Number{}); +} + +} // namespace lds_utils +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index 8ad6ddca9d..974da56649 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -36,6 +36,11 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); + +void add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); #endif #ifdef CK_ENABLE_FP32 void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( @@ -192,6 +197,7 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index ec4c27598f..aaa0d7e960 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -8,6 +8,7 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..f0a54ee400 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
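The split-K grid launched by the kernel above uses k_batch as its third dimension, so K has to be padded up to whole KBatch * K0PerBlock * K1 chunks before it is unmerged into (KBatch, K0Padded, K1); CalculateK0Padded and CalculateKPadded do that rounding. A standalone sketch of the arithmetic with assumed sizes (not CK code):

#include <cstdio>

int main()
{
    constexpr int K          = 1000; // problem K (assumed)
    constexpr int K0PerBlock = 4, K1 = 8, KBatch = 4; // tuning parameters (assumed)

    constexpr int K_t      = KBatch * K0PerBlock * K1;         // = 128 elements per full chunk
    constexpr int K0Padded = (K + K_t - 1) / K_t * K0PerBlock; // = ceil(1000/128) * 4 = 32
    constexpr int KPadded  = KBatch * K0Padded * K1;           // = 4 * 32 * 8 = 1024

    static_assert(KPadded >= K && KPadded % K_t == 0, "each K-split must get whole tiles");
    std::printf("K0Padded=%d, KPadded=%d (K padded %d -> %d)\n", K0Padded, KPadded, K, KPadded);
    return 0;
}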
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#######################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#######################################| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 4, 16, 16, 16, 1, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, 
PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 4, 32, 16, 16, 1, 2, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 4, 32, 16, 16, 1, 2, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, 
F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 753cef783fe70b753f42bc6b4d008400980dc1b4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:24:32 -0800 Subject: [PATCH 72/75] Bump rocm-docs-core from 0.33.1 to 0.33.2 in /docs/sphinx (#1160) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.33.1 to 0.33.2. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.33.1...v0.33.2) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index c80177bd30..a6b286b131 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.33.1 +rocm-docs-core==0.33.2 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index a36f5e2be8..4bbe95c934 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.33.1 +rocm-docs-core==0.33.2 # via -r requirements.in six==1.16.0 # via From ba86eadce5ad22eca266d6af4fe4da1eee50fa79 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 7 Feb 2024 15:54:13 +0100 Subject: [PATCH 73/75] Add support for mixed-precision f16bf16_int8 gemm (#1127) --- .../element/binary_element_wise_operation.hpp | 83 ++++++- .../element/unary_element_wise_operation.hpp | 21 +- .../device_operation_instance_factory.hpp | 4 +- .../gpu/gemm_add.hpp | 114 +++++++++ .../gpu/gemm_add_fastgelu.hpp | 54 +++- .../gpu/gemm_add_relu.hpp | 116 +++++++++ .../gpu/gemm_add_silu.hpp | 116 +++++++++ .../gpu/gemm_add/CMakeLists.txt | 4 + ...bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp | 69 ++++++ ...le_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp | 69 ++++++ .../gpu/gemm_add_fastgelu/CMakeLists.txt | 2 + ...bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp | 73 ++++++ ...le_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp | 72 ++++++ .../gpu/gemm_add_relu/CMakeLists.txt | 4 + ...bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp | 71 ++++++ ...le_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp | 70 ++++++ .../gpu/gemm_add_silu/CMakeLists.txt | 4 + ...bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp | 71 ++++++ ...le_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp | 70 ++++++ .../profiler/profile_gemm_add_impl.hpp | 232 ++++++++++++++++++ .../profiler/profile_gemm_add_relu_impl.hpp | 232 ++++++++++++++++++ .../profiler/profile_gemm_add_silu_impl.hpp | 232 ++++++++++++++++++ profiler/src/CMakeLists.txt | 6 + profiler/src/profile_gemm_add.cpp | 139 +++++++++++ profiler/src/profile_gemm_add_fastgelu.cpp | 17 +- profiler/src/profile_gemm_add_relu.cpp | 139 +++++++++++ profiler/src/profile_gemm_add_silu.cpp | 139 +++++++++++ test/CMakeLists.txt | 1 + test/gemm_add/CMakeLists.txt | 11 + test/gemm_add/test_gemm_add.hpp | 72 ++++++ test/gemm_add/test_gemm_add_fastgelu.cpp | 41 ++++ test/gemm_add/test_gemm_add_relu.cpp | 41 ++++ test/gemm_add/test_gemm_add_silu.cpp | 41 ++++ 33 files changed, 2424 insertions(+), 6 deletions(-) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_add_impl.hpp create mode 100644 profiler/include/profiler/profile_gemm_add_relu_impl.hpp create mode 100644 profiler/include/profiler/profile_gemm_add_silu_impl.hpp create mode 100644 profiler/src/profile_gemm_add.cpp create mode 100644 profiler/src/profile_gemm_add_relu.cpp create mode 100644 profiler/src/profile_gemm_add_silu.cpp create mode 100644 test/gemm_add/CMakeLists.txt create mode 100644 test/gemm_add/test_gemm_add.hpp create mode 100644 test/gemm_add/test_gemm_add_fastgelu.cpp create mode 100644 test/gemm_add/test_gemm_add_relu.cpp create mode 100644 test/gemm_add/test_gemm_add_silu.cpp diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index f0f3b0f167..95048469cd 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -75,6 +75,15 @@ struct Add y = ck::type_convert(y_tmp); } + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = ck::type_convert(y_tmp); + } + template <> __host__ __device__ constexpr void operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const @@ -264,6 +273,14 @@ struct AddRelu y = a > 0.0f ? a : 0.0f; }; + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > type_convert(0.0f) ? 
a : type_convert(0.0f); + }; + template <> __host__ __device__ constexpr void operator()(int& y, const int& x0, const int8_t& x1) const @@ -354,6 +371,70 @@ struct AddFastGelu e = type_convert(x1_f); } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +// E = Silu(C + D) +struct AddSilu +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const + { + const half_t x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c + d; + + float x1_f = 0; + + Silu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + Silu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index eed60caef4..db89a79723 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -156,6 +156,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(bhalf_t& y, const int8_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int32_t& x) const { @@ -551,6 +557,19 @@ struct Sigmoid }; }; +struct Silu +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = x * (one / (one + ck::math::exp(-x))); + }; +}; + struct TanH { template diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index dc47c7ec1a..d88b9fd373 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
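The Silu functor added above computes x * sigmoid(x), i.e. x / (1 + e^(-x)), and AddSilu applies it to the sum of the float accumulator c and the 16-bit residual d, converting to the 16-bit output type only at the very end. A plain-float sketch of the math (no CK types; the asserted bound is just a sanity check):

#include <cassert>
#include <cmath>

// Plain-float sketch, not CK code: Silu and the fused AddSilu epilogue.
float silu(float x) { return x / (1.0f + std::exp(-x)); } // == x * sigmoid(x)

float add_silu(float c, float d) { return silu(c + d); }  // e = Silu(c + d)

int main()
{
    // silu(2) = 2 / (1 + e^-2) ~= 1.7616; here the accumulator c = 1 and residual d = 1.
    const float e = add_silu(1.0f, 1.0f);
    assert(e > 1.76f && e < 1.77f);
    return 0;
}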
#pragma once @@ -98,6 +98,8 @@ using Scale = ck::tensor_operation::element_wise::Scale; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; +using AddSilu = ck::tensor_operation::element_wise::AddSilu; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp new file mode 100644 index 0000000000..030f3c2760 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); + +// GEMM + Add + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + Add>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp index fd3550c2f0..555b52de75 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
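The GEMM+Add factory specialization above selects instance lists purely from the layout and data-type template arguments. A hedged sketch of how a client could enumerate the f16 x i8 instances; the template-argument order mirrors the factory above, and the printing loop is illustrative only, not part of the patch:

#include <iostream>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add.hpp"

using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add         = ck::tensor_operation::element_wise::Add;

int main()
{
    // DeviceGemmMultipleD<ALayout, BLayout, DsLayout, ELayout,
    //                     ADataType, BDataType, DsDataType, EDataType, AOp, BOp, CDEOp>
    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
        Row, Row, ck::Tuple<Row>, Row,
        ck::half_t, int8_t, ck::Tuple<ck::half_t>, ck::half_t,
        PassThrough, PassThrough, Add>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " gemm+add instances\n";
    for(const auto& op_ptr : op_ptrs)
        std::cout << op_ptr->GetTypeString() << '\n'; // each instance reports its config
    return 0;
}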
#pragma once @@ -68,6 +68,32 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_inst PassThrough, AddFastGelu>>>&); +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); + // GEMM + Add + FastGelu template > op_ptrs; +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp new file mode 100644 index 0000000000..293e14b811 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); + +// GEMM + Add + Relu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddRelu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp new file mode 100644 index 0000000000..fbf45852ce --- /dev/null +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>&); + +// GEMM + Add + Silu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddSilu>> +{ + using DeviceOp = DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + PassThrough, + PassThrough, + AddSilu>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + +#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + op_ptrs); + } + } +#endif + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt new file mode 100644 index 0000000000..fe85bb7ead --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_gemm_add_instance + device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..c489ef5e53 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
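As a semantic reference for these mixed-precision instances: each one computes e[m, n] = (sum over k of a[m, k] * b[k, n]) + d[m, n], widening both the f16/bf16 activations and the int8 weights to float before accumulating (AccDataType is F32 in every row). A hedged host-side reference with inputs pre-widened to float to avoid host f16 types (not part of the patch):

#include <cstdint>
#include <vector>

// Reference sketch, not CK code: row-major A[M,K] * B[K,N] + D[M,N], float accumulation.
std::vector<float> gemm_add_ref(const std::vector<float>& a,  // f16/bf16 values, pre-widened
                                const std::vector<int8_t>& b, // int8 weights
                                const std::vector<float>& d,  // 16-bit bias, pre-widened
                                int M, int N, int K)
{
    std::vector<float> e(static_cast<std::size_t>(M) * N);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f; // AccDataType stays float even for 16-bit outputs
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * static_cast<float>(b[k * N + n]);
            e[m * N + n] = acc + d[m * N + n]; // the Add epilogue; the device op then
                                               // narrows this to f16/bf16
        }
    return e;
}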
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| 
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..7ec11f4a07 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add/device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
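The `Add` epilogue these gemm_add instances bind as the CDE elementwise operation is a fused bias add on the f32 accumulator. A host-side sketch of the semantics (illustrative; CK's actual functor is the device-callable `element_wise::Add`):

```cpp
// e = convert<EDataType>(c + d0): c is the f32 GEMM accumulator value and
// d0 the extra input tensor read at the same [m, n] coordinate.
struct AddSketch
{
    template <typename E, typename C, typename D>
    void operator()(E& e, const C& c, const D& d0) const
    {
        e = static_cast<E>(static_cast<float>(c) + static_cast<float>(d0));
    }
};
```

This is the same per-element step the profiler's verification loop replays on the host after running a reference GEMM.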
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| 
MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 0beb10e379..63b4a00c99 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,6 +1,8 @@ add_instance_library(device_gemm_add_fastgelu_instance + device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp + device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..baaeac618e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, 
Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..fc395b463b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
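`AddFastGelu` composes the same bias add with the fast-GELU activation. A scalar sketch of the intended math, using the common tanh approximation (the device implementation may use a different but numerically close formulation):

```cpp
#include <cmath>

// e = fastgelu(c + d0), with
// gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
float add_fastgelu(float c, float d0)
{
    const float x = c + d0;
    const float u = 0.7978845608f * (x + 0.044715f * x * x * x);
    return 0.5f * x * (1.0f + std::tanh(u));
}
```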
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt new file mode 100644 index 0000000000..969361de9a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_gemm_add_relu_instance + device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..8eac8a0505 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
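Every instance file here registers two tuples through `add_device_operation_instances`: a `*_generic_instances` tuple whose single entry uses conservative vector widths, and a tuned tuple whose entries demand wider vector access on the operands. The generic entry is the fallback that keeps odd shapes and strides supported when the tuned kernels reject an argument. Conceptually the helper expands a type tuple into heap-allocated objects behind the common base pointer; a stand-alone sketch of the idea (CK's real helper differs in detail):

```cpp
#include <memory>
#include <tuple>
#include <vector>

// For each concrete kernel type in the tuple, default-construct one object
// and hand it to the caller through the shared base-class pointer.
template <typename BaseOp, typename... Instances>
void append_instances(std::vector<std::unique_ptr<BaseOp>>& ops,
                      std::tuple<Instances...>)
{
    (ops.push_back(std::make_unique<Instances>()), ...);
}
```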
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..f1269c3434 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
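`AddRelu` is the simplest of the three new epilogues: bias add, then clamp at zero. Scalar illustration only:

```cpp
// e = max(c + d0, 0), computed in f32 before conversion to f16/bf16
float add_relu(float c, float d0)
{
    const float x = c + d0;
    return x > 0.0f ? x : 0.0f;
}
```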
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddRelu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt new file mode 100644 index 0000000000..c10d4773a7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt @@ -0,0 +1,4 @@ +add_instance_library(device_gemm_add_silu_instance + device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp + device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..088b2fd7ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
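Note that every instance in this patch, generic and tuned alike, is built with `GemmSpecialization::MNKPadding`, so M, N and K need not divide the block tile (for example 16 x 128 x 32 in the first tuned entry). The cost is some wasted work on the padded fringe; the arithmetic is a plain round-up:

```cpp
#include <cstdio>

// Round a problem dimension up to the next multiple of its block-tile size.
int padded(int dim, int tile) { return (dim + tile - 1) / tile * tile; }

int main()
{
    // e.g. M = 1000 with MPerBlock = 16: the grid covers 1008 rows and the
    // 8 extra rows are masked off when results are written out.
    std::printf("%d\n", padded(1000, 16)); // prints 1008
    return 0;
}
```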
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, BF16, I8, F32, F32,BF16_Tuple, BF16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp new file mode 100644 index 0000000000..41a45c6de6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
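`AddSilu` rounds out the set with the SiLU (swish) activation. Scalar illustration only:

```cpp
#include <cmath>

// e = silu(c + d0), where silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
float add_silu(float c, float d0)
{
    const float x = c + d0;
    return x / (1.0f + std::exp(-x));
}
```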
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// outout: e[m, n] +// input: a[m, k], b[k, n], d0[m, n], d1[m, n] +using device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| 
AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 256, 16, 128, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, I8, F32, F32, F16_Tuple, F16, PassThrough, PassThrough, AddSilu, GemmMNKPadding, 1, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + +void add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_generic_instances{}); + add_device_operation_instances( + instances, device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_add_impl.hpp b/profiler/include/profiler/profile_gemm_add_impl.hpp new file mode 100644 index 0000000000..502d2b2951 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_add_impl.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
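The `f_host_tensor_descriptor` lambda at the top of the profiler implementation that follows encodes the usual stride convention for 2-D host tensors. The rule, spelled out as a stand-alone helper (names are hypothetical):

```cpp
#include <array>
#include <cstddef>

enum class Layout { RowMajor, ColMajor };

// Row-major: consecutive column elements are contiguous -> {stride, 1}.
// Col-major: consecutive row elements are contiguous    -> {1, stride}.
std::array<std::size_t, 2> make_strides(std::size_t stride, Layout layout)
{
    return layout == Layout::RowMajor
               ? std::array<std::size_t, 2>{stride, 1}
               : std::array<std::size_t, 2>{1, stride};
}
```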
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_add_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Add = ck::tensor_operation::element_wise::Add; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = Add; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Add>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + 
for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp new file mode 100644 index 0000000000..5d79a98c11 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
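On the performance numbers `profile_gemm_add_impl` prints above: `ave_time` is in milliseconds, so `flop / 1.E9 / ave_time` comes out in TFLOPS and `num_btype / 1.E6 / ave_time` in GB/s. A worked example with hypothetical values (note that `num_btype` counts A, B and E traffic only, ignoring the D0 read, exactly as in the profiler):

```cpp
#include <cstdio>

int main()
{
    const double M = 1024, N = 1024, K = 1024;
    const double ms = 0.5;                 // hypothetical kernel time
    const double flop = 2.0 * M * N * K;   // one FMA counts as 2 flops
    // f16 A (2 B/elem), int8 B (1 B/elem), f16 E (2 B/elem)
    const double bytes = 2 * M * K + 1 * K * N + 2 * M * N;
    std::printf("%.2f TFLOPS, %.2f GB/s\n", flop / 1e9 / ms, bytes / 1e6 / ms);
    return 0;
}
```

The `profile_gemm_add_relu_impl` that follows repeats the identical pattern with `AddRelu` as the CDE elementwise operation.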
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_relu.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_add_relu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddRelu = ck::tensor_operation::element_wise::AddRelu; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddRelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddRelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 
0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_add_silu_impl.hpp b/profiler/include/profiler/profile_gemm_add_silu_impl.hpp new file mode 100644 index 0000000000..e8a96208f6 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_add_silu_impl.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
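Between the relu header above and the silu header below, the only functional difference is the CDE functor applied after the add. As a rough scalar reference of what these fused epilogues compute, assuming the usual definitions relu(x) = max(x, 0) and silu(x) = x * sigmoid(x) (a sketch, not the library's templated functors):

    #include <cmath>
    #include <cstdio>

    // e = relu(c + d0)
    float add_relu(float c, float d0)
    {
        const float x = c + d0;
        return x > 0.0f ? x : 0.0f;
    }

    // e = silu(c + d0), silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    float add_silu(float c, float d0)
    {
        const float x = c + d0;
        return x / (1.0f + std::exp(-x));
    }

    int main()
    {
        std::printf("add_relu(-2, 1) = %f\n", add_relu(-2.0f, 1.0f)); // 0
        std::printf("add_silu(-2, 1) = %f\n", add_silu(-2.0f, 1.0f)); // ~-0.269
        return 0;
    }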
+ +#pragma once + +#include <iomanip> + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_silu.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template <typename ADataType, + typename BDataType, + typename AccDataType, + typename D0DataType, + typename EDataType, + typename ALayout, + typename BLayout, + typename D0Layout, + typename ELayout> +bool profile_gemm_add_silu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddSilu = ck::tensor_operation::element_wise::AddSilu; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddSilu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<D0Layout>, + ELayout, + ADataType, + BDataType, + ck::Tuple<D0DataType>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddSilu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor<AccDataType> c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, + BDataType, + AccDataType, + AccDataType, + AElementOp, + BElementOp, + PassThrough>; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 
0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index e9cf6eecfb..c4b54d235f 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -43,7 +43,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) @@ -109,7 +112,10 @@ if(DL_KERNELS) endif() if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE 
device_gemm_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) diff --git a/profiler/src/profile_gemm_add.cpp b/profiler/src/profile_gemm_add.cpp new file mode 100644 index 0000000000..749966af1b --- /dev/null +++ b/profiler/src/profile_gemm_add.cpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include <iostream> +#include <numeric> +#include <initializer_list> +#include <cstdlib> + +#include "profiler/profile_gemm_add_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_add" +#define OP_DESC "GEMM+Add" + +using INT8 = int8_t; +using BF16 = ck::bhalf_t; + +int profile_gemm_add(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN, // 0 + MK_NK_MN_MN, // 1 + KM_KN_MN_MN, // 2 + KM_NK_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F16_INT8_F16_F16, // 0 + BF16_INT8_BF16_BF16, // 1 + }; + + if(argc != 15) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: f16&i8 1: bf16&i8)\n"); + printf("arg3: matrix layout (0: E[m, n] = A[m, k] * B[k, n] + D0[m, n];\n"); + printf(" 1: E[m, n] = A[m, k] * B[n, k] + D0[m, n];\n"); + printf(" 2: E[m, n] = A[k, m] * B[k, n] + D0[m, n];\n"); + printf(" 3: E[m, n] = A[k, m] * B[n, k] + D0[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast<MatrixDataType>(std::stoi(argv[2])); + const auto layout = static_cast<MatrixLayout>(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideE = std::stoi(argv[14]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + // using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M; + const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? 
N : K; + const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M; + const int DefaultStrideE = ck::is_same_v<ELayout, Row> ? N : M; + + bool pass = ck::profiler::profile_gemm_add_impl<ADataType, + BDataType, + AccDataType, + D0DataType, + EDataType, + ALayout, + BLayout, + D0Layout, + ELayout>( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_INT8_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, INT8{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::BF16_INT8_BF16_BF16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(BF16{}, INT8{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add); diff --git a/profiler/src/profile_gemm_add_fastgelu.cpp b/profiler/src/profile_gemm_add_fastgelu.cpp index a09bb8340d..f8335d8c05 100644 --- a/profiler/src/profile_gemm_add_fastgelu.cpp +++ b/profiler/src/profile_gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include <iostream> #include <numeric> @@ -12,6 +12,9 @@ #define OP_NAME "gemm_add_fastgelu" #define OP_DESC "GEMM+Add+FastGeLU" +using INT8 = int8_t; +using BF16 = ck::bhalf_t; + int profile_gemm_add_fastgelu(int argc, char* argv[]) { enum struct MatrixLayout @@ -28,13 +31,15 @@ int profile_gemm_add_fastgelu(int argc, char* argv[]) F16_F16_F16_F16, // 1 BF16_BF16_BF16_BF16, // 2 INT8_INT8_INT8_INT8, // 3 + F16_INT8_F16_F16, // 4 + BF16_INT8_BF16_BF16, // 5 }; if(argc != 15) { // clang-format off printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); - printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f16&i8 5: bf16&i8)\n"); printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n]);\n"); printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n]);\n"); printf(" 2: E[m, n] = FastGeLU(A[k, m] * B[k, n] + D0[m, n]);\n"); @@ -135,6 +140,14 @@ int profile_gemm_add_fastgelu(int argc, char* argv[]) { return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}); } + else if(data_type == MatrixDataType::F16_INT8_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, INT8{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::BF16_INT8_BF16_BF16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(BF16{}, INT8{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_gemm_add_relu.cpp b/profiler/src/profile_gemm_add_relu.cpp new file mode 100644 index 0000000000..025fddc82b --- /dev/null +++ b/profiler/src/profile_gemm_add_relu.cpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
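One convention worth calling out before the next two drivers, which repeat it verbatim: a negative stride on the command line selects the packed default for the chosen layout, so a row-major A of shape M x K gets StrideA = K while a column-major A gets StrideA = M, and likewise for B, D0, and E. A self-contained illustration:

    #include <iostream>

    int main()
    {
        const int M = 1024, K = 256;
        const bool a_row_major = true;

        // mirrors the DefaultStride logic in profile_gemm_add above
        const int default_stride_a = a_row_major ? K : M; // packed leading dimension
        int stride_a = -1;                                // "-1" passed on the command line
        stride_a = (stride_a < 0) ? default_stride_a : stride_a;

        std::cout << "StrideA = " << stride_a << std::endl; // prints 256
        return 0;
    }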
+ +#include +#include +#include +#include + +#include "profiler/profile_gemm_add_relu_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_add_relu" +#define OP_DESC "GEMM+Add+ReLU" + +using INT8 = int8_t; +using BF16 = ck::bhalf_t; + +int profile_gemm_add_relu(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN, // 0 + MK_NK_MN_MN, // 1 + KM_KN_MN_MN, // 2 + KM_NK_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F16_INT8_F16_F16, // 0 + BF16_INT8_BF16_BF16, // 1 + }; + + if(argc != 15) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: f16&i8 1: bf16&i8)\n"); + printf("arg3: matrix layout (0: E[m, n] = ReLU(A[m, k] * B[k, n] + D0[m, n]);\n"); + printf(" 1: E[m, n] = ReLU(A[m, k] * B[n, k] + D0[m, n]);\n"); + printf(" 2: E[m, n] = ReLU(A[k, m] * B[k, n] + D0[m, n]);\n"); + printf(" 3: E[m, n] = ReLU(A[k, m] * B[n, k] + D0[m, n]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideE = std::stoi(argv[14]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + // using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_add_relu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 
0 : 1; + }; + + if(data_type == MatrixDataType::F16_INT8_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, INT8{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::BF16_INT8_BF16_BF16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(BF16{}, INT8{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_relu); diff --git a/profiler/src/profile_gemm_add_silu.cpp b/profiler/src/profile_gemm_add_silu.cpp new file mode 100644 index 0000000000..daaaef0fa2 --- /dev/null +++ b/profiler/src/profile_gemm_add_silu.cpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include <iostream> +#include <numeric> +#include <initializer_list> +#include <cstdlib> + +#include "profiler/profile_gemm_add_silu_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_add_silu" +#define OP_DESC "GEMM+Add+SiLU" + +using INT8 = int8_t; +using BF16 = ck::bhalf_t; + +int profile_gemm_add_silu(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN, // 0 + MK_NK_MN_MN, // 1 + KM_KN_MN_MN, // 2 + KM_NK_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F16_INT8_F16_F16, // 0 + BF16_INT8_BF16_BF16, // 1 + }; + + if(argc != 15) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: f16&i8 1: bf16&i8)\n"); + printf("arg3: matrix layout (0: E[m, n] = SiLU(A[m, k] * B[k, n] + D0[m, n]);\n"); + printf(" 1: E[m, n] = SiLU(A[m, k] * B[n, k] + D0[m, n]);\n"); + printf(" 2: E[m, n] = SiLU(A[k, m] * B[k, n] + D0[m, n]);\n"); + printf(" 3: E[m, n] = SiLU(A[k, m] * B[n, k] + D0[m, n]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast<MatrixDataType>(std::stoi(argv[2])); + const auto layout = static_cast<MatrixLayout>(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideE = std::stoi(argv[14]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + // using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? 
K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_add_silu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16_INT8_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(F16{}, INT8{}, F32{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::BF16_INT8_BF16_BF16 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(BF16{}, INT8{}, F32{}, BF16{}, BF16{}, Row{}, Row{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_silu); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fa5f8583af..a0f90256c0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -122,6 +122,7 @@ add_subdirectory(space_filling_curve) add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) +add_subdirectory(gemm_add) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt new file mode 100644 index 0000000000..7df3f90abc --- /dev/null +++ b/test/gemm_add/CMakeLists.txt @@ -0,0 +1,11 @@ +add_gtest_executable(test_gemm_add test_gemm_add.hpp) +target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) + +add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp) +target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) + +add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp) +target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) + +add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp) +target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) diff --git a/test/gemm_add/test_gemm_add.hpp b/test/gemm_add/test_gemm_add.hpp new file mode 100644 index 0000000000..11d3d1c10a --- /dev/null +++ b/test/gemm_add/test_gemm_add.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
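The fixture that follows leans on GoogleTest's typed tests: one TYPED_TEST body is stamped out for every tuple in KernelTypes, and each derived suite only swaps the profiler entry point. In miniature, the mechanism looks like this (an illustrative sketch, not the fixture below):

    #include <tuple>
    #include "gtest/gtest.h"

    template <typename Tuple>
    class SmallTypedTest : public ::testing::Test
    {
        protected:
        using First = std::tuple_element_t<0, Tuple>;

        void Run() { EXPECT_GT(sizeof(First), 0u); }
    };

    using TwoTuples = ::testing::Types<std::tuple<float, int>, std::tuple<double, char>>;
    TYPED_TEST_SUITE(SmallTypedTest, TwoTuples);

    // runs once per tuple in TwoTuples
    TYPED_TEST(SmallTypedTest, RunsForEveryTuple) { this->Run(); }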
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_impl.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +class TestGemmAdd : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddImpl = ck::profiler::profile_gemm_add_impl; + + virtual decltype(ProfileGemmAddImpl) GetImpl() { return ProfileGemmAddImpl; } + + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {2048, 4096, 8192}, {2048, 1024, 16}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = ck::is_same_v ? N : M; + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & + GetImpl()(true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAdd, KernelTypes); +TYPED_TEST(TestGemmAdd, Test_BF16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_fastgelu.cpp b/test/gemm_add/test_gemm_add_fastgelu.cpp new file mode 100644 index 0000000000..c1c55140a0 --- /dev/null +++ b/test/gemm_add/test_gemm_add_fastgelu.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_fastgelu_impl.hpp" +#include "test_gemm_add.hpp" + +template +class TestGemmAddFastgelu : public TestGemmAdd +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddFastgeluImpl = + ck::profiler::profile_gemm_add_fastgelu_impl; + + decltype(ProfileGemmAddFastgeluImpl) GetImpl() override { return ProfileGemmAddFastgeluImpl; } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddFastgelu, KernelTypes); +TYPED_TEST(TestGemmAddFastgelu, Test_BF16FP16) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_relu.cpp b/test/gemm_add/test_gemm_add_relu.cpp new file mode 100644 index 0000000000..ba6aab36bd --- /dev/null +++ b/test/gemm_add/test_gemm_add_relu.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
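As with the fastgelu test above, the relu and silu suites reuse TestGemmAdd::Run unchanged and only override GetImpl() to hand back a different profiler function. Condensed to its essentials, the idiom is a virtual method returning a function pointer (a sketch under the assumption that all impls share one signature; the names here are hypothetical):

    using ProfileFn = bool (*)(int M, int N, int K);

    bool profile_add(int, int, int) { return true; }
    bool profile_add_relu(int, int, int) { return true; }

    struct TestAdd
    {
        virtual ~TestAdd() = default;
        virtual ProfileFn GetImpl() { return profile_add; }

        // shared driver: derived classes change only which impl it calls
        bool Run() { return GetImpl()(16, 32, 64); }
    };

    struct TestAddRelu : TestAdd
    {
        ProfileFn GetImpl() override { return profile_add_relu; }
    };

    int main()
    {
        TestAddRelu t;
        return t.Run() ? 0 : 1;
    }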
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_relu_impl.hpp" +#include "test_gemm_add.hpp" + +template +class TestGemmAddRelu : public TestGemmAdd +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddReluImpl = + ck::profiler::profile_gemm_add_relu_impl; + + decltype(ProfileGemmAddReluImpl) GetImpl() override { return ProfileGemmAddReluImpl; } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddRelu, KernelTypes); +TYPED_TEST(TestGemmAddRelu, Test_BF16FP16_INT8) { this->Run(); } diff --git a/test/gemm_add/test_gemm_add_silu.cpp b/test/gemm_add/test_gemm_add_silu.cpp new file mode 100644 index 0000000000..d4dd6fa38b --- /dev/null +++ b/test/gemm_add/test_gemm_add_silu.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_add_silu_impl.hpp" +#include "test_gemm_add.hpp" + +template +class TestGemmAddSilu : public TestGemmAdd +{ + private: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using EDataType = std::tuple_element_t<4, Tuple>; + using ALayout = std::tuple_element_t<5, Tuple>; + using BLayout = std::tuple_element_t<6, Tuple>; + using D0Layout = std::tuple_element_t<7, Tuple>; + using ELayout = std::tuple_element_t<8, Tuple>; + + constexpr static auto ProfileGemmAddSiluImpl = + ck::profiler::profile_gemm_add_silu_impl; + + decltype(ProfileGemmAddSiluImpl) GetImpl() override { return ProfileGemmAddSiluImpl; } +}; + +using KernelTypes = ::testing::Types, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddSilu, KernelTypes); +TYPED_TEST(TestGemmAddSilu, Test_BF16FP16_INT8) { this->Run(); } From 1b0fbaebbb7a9235cf44c4b32a2698b97509901d Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 7 Feb 2024 12:47:12 -0800 Subject: [PATCH 74/75] Split-up instances to improve build times. 
(#1159) * split up splitk-gemm instances * clean up some unused variables * split the mk_kn_mn interwave splitk-gemm instances * split up f16_f16_f16 mk_nk_mn splitk gemm instances * fix clang format * fix function names * fix typo * split up the 2 largest fp16*fp8 splitk gemm instances * get rid of unused variables * split up the largest splitk-gemm fp8*fp16 instance file * split up the instances for xdl fp8 gemms * split the headers for f16 and i8 for wmmma convolution instances --- ..._shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp | 102 -------- ...uffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp | 59 +++++ ...fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp | 59 +++++ ...uffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp | 59 +++++ .../tensor_operation_instance/gpu/gemm.hpp | 30 ++- .../gpu/gemm_splitk.hpp | 124 +++++++++- .../gpu/gemm_streamk.hpp | 14 +- ...ouped_conv_bwd_data_wmma_f16_instance.hpp} | 0 ...grouped_conv_bwd_data_wmma_i8_instance.hpp | 118 ++++++++++ .../gpu/gemm/CMakeLists.txt | 8 +- ..._fp8_fp8_mk_kn_mn_v1_default_instance.cpp} | 6 +- ...mk_kn_mn_v1_interwave_default_instance.cpp | 27 +++ ..._mk_kn_mn_v1_interwave_padded_instance.cpp | 27 +++ ...8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp} | 6 +- ...8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp | 26 +++ ...p8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp | 26 +++ .../gpu/gemm_splitk/CMakeLists.txt | 27 ++- ...l_splitk_f16_f16_f16_mk_kn_mn_instance.cpp | 217 ------------------ ...plitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp | 95 ++++++++ ...f16_f16_mk_kn_mn_v1_interwave_instance.cpp | 81 +++++++ ..._kn_mn_v1_interwave_irregular_instance.cpp | 95 ++++++++ ...f16_f16_mk_kn_mn_v1_irregular_instance.cpp | 96 ++++++++ ...plitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp | 79 +++++++ ...f16_f16_mk_kn_mn_v2_irregular_instance.cpp | 95 ++++++++ ...l_splitk_f16_f16_f16_mk_nk_mn_instance.cpp | 202 ---------------- ...plitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp | 90 ++++++++ ...f16_f16_mk_nk_mn_v1_interwave_instance.cpp | 76 ++++++ ..._nk_mn_v1_interwave_irregular_instance.cpp | 95 ++++++++ ...f16_f16_mk_nk_mn_v1_irregular_instance.cpp | 95 ++++++++ ...plitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp | 76 ++++++ ...f16_f16_mk_nk_mn_v2_irregular_instance.cpp | 95 ++++++++ ...l_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp | 153 ------------ ...16_fp8_f16_mk_kn_mn_irregular_instance.cpp | 61 +++++ ...plitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp | 96 ++++++++ ...fp8_f16_mk_kn_mn_v1_interwave_instance.cpp | 82 +++++++ ...plitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp | 80 +++++++ ...litk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp} | 34 +-- ...fp8_f16_mk_nk_mn_v1_interwave_instance.cpp | 77 +++++++ ...plitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp | 77 +++++++ ...l_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp | 135 ----------- ...plitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp | 95 ++++++++ ...f16_f16_mk_kn_mn_v1_interwave_instance.cpp | 81 +++++++ ...plitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp | 79 +++++++ ...gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp | 2 +- ...ta_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 2 +- ..._gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp | 2 +- ...ata_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp | 2 +- ...nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp | 2 +- ...ta_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 2 +- ..._nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp | 2 +- ...ata_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp | 2 +- ...hwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp | 2 +- ...wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp | 2 +- ...dhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp | 2 +- 
..._wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp | 2 +- ...wgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp | 2 +- ...wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 2 +- ...hwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp | 2 +- ..._wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp | 2 +- 59 files changed, 2401 insertions(+), 886 deletions(-) delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp rename library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/{device_grouped_conv_bwd_data_wmma_instance.hpp => device_grouped_conv_bwd_data_wmma_f16_instance.hpp} (100%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp rename library/src/tensor_operation_instance/gpu/gemm/{device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp => device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp} (74%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp rename library/src/tensor_operation_instance/gpu/gemm/{device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp => device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp} (75%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp => device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp} (50%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp deleted file mode 100644 index 005cec94ec..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp +++ /dev/null @@ -1,102 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
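The deletion that starts here removes the monolithic fp8 instance header; per the commit message, its contents reappear split across _v1, _v1_interwave, and _v2 headers so that each pipeline variant is instantiated from its own translation unit and the heavy template instantiation can proceed in parallel across compiler jobs. Schematically, with shortened hypothetical type names standing in for the DeviceGemm_Xdl_CShuffle instances:

    #include <tuple>

    struct V1_a {}; struct V1_b {};               // hypothetical pipeline-v1 instances
    struct Interwave_a {}; struct Interwave_b {}; // hypothetical interwave instances
    struct V2_a {}; struct V2_b {};               // hypothetical pipeline-v2 instances

    // before: every variant in one tuple, instantiated by a single .cpp,
    // producing one long serial compile
    using all_instances = std::tuple<V1_a, V1_b, Interwave_a, Interwave_b, V2_a, V2_b>;

    // after: one tuple per variant, each compiled in its own translation unit,
    // so a parallel build can work on the three files at once
    using v1_instances           = std::tuple<V1_a, V1_b>;
    using v1_interwave_instances = std::tuple<Interwave_a, Interwave_b>;
    using v2_instances           = std::tuple<V2_a, V2_b>;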
- -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#ifdef CK_ENABLE_FP8 -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F32 = float; -using F8 = f8_t; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = std::tuple< - // clang-format off - //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| - //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | - //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | - //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // pipeline v1, 1 wave - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< 
Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, 
PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1> -#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES - // pipeline v1, 2 waves - , - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1> - -#endif -#if 0 - //CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES - // pipeline v2, 1 wave - , - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - 
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, 
LoopScheduler::Default, PipelineVersion::v2>, - DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2> -#endif - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp new file mode 100644 index 0000000000..ca1d56c769 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 
4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 
16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp new file mode 100644 index 0000000000..7c215eb212 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
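+
+// The tuple below reuses the "pipeline v1" configurations from the v1 header
+// but runs them with LoopScheduler::Interwave (two waves in flight), so one
+// wave's global loads can overlap another wave's MFMA work.
+// A minimal, hypothetical usage sketch: the registration function is the one
+// declared in gemm.hpp, while the exact DeviceGemm interface arguments are an
+// assumption made here for illustration only.
+//
+//   std::vector<std::unique_ptr<DeviceGemm<Row, Row, Row, F8, F8, F8,
+//       PassThrough, PassThrough, PassThrough>>> op_ptrs;
+//   add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(op_ptrs);
+//   // Filter with op->IsSupportedArgument(...) before timing each instance.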
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 2 waves + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + 
DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, 
S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Interwave, PipelineVersion::v1> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp new file mode 100644 index 0000000000..ee361bae51 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v2, 1 wave + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index 626dd7f00a..31e5b72ea1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -345,11 +345,27 @@ void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( std::vector>>& instances); -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances( +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( std::vector>>& instances); -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances( +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( std::vector>>& instances); @@ -579,8 +595,14 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances(op_ptrs); - add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( + op_ptrs); + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( + op_ptrs); + 
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index 974da56649..ebbe7c7211 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -27,12 +27,62 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances( DeviceGemmSplitK>>& instances); -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances( std::vector>>& instances); -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instances( std::vector>>& instances); @@ -74,7 +124,17 @@ void add_device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instances( DeviceGemmSplitK>>& instances); -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_interwave_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v2_instances( std::vector>>& instances); @@ -94,12 +154,37 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances( DeviceGemmSplitK>>& instances); -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_instances( std::vector>>& instances); -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_interwave_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v2_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances( std::vector>>& instances); @@ -191,12 +276,24 @@ struct DeviceOperationInstanceFactory< if 
constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instances( + op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instances( + op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instances(op_ptrs); add_device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && @@ -218,7 +315,9 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_instances(op_ptrs); + add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_interwave_instances(op_ptrs); + add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v2_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -242,12 +341,17 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_interwave_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v2_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp index 730785f702..0e6b40b2e7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_streamk.hpp @@ -83,12 +83,22 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instances(op_ptrs); + 
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp similarity index 100% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp rename to library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp new file mode 100644 index 0000000000..5db8226e11 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
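+
+// Int8 companion to the renamed f16 instance header: these tuples pair I8
+// inputs with I32 accumulation (per the aliases below) and use 16-wide K1 and
+// scalar-per-vector settings where the f16 lists use 8-wide ones.
+// Hypothetical factory-side sketch; the interface template arguments (layouts
+// in particular) are an assumption for illustration, not taken from this file:
+//
+//   using ConvOp = DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC,
+//       Empty_Tuple, GNHWC, I8, I8, Empty_Tuple, I8,
+//       PassThrough, PassThrough, PassThrough>;
+//   auto op_ptrs = DeviceOperationInstanceFactory<ConvOp>::GetInstances();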
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; +using I32 = int32_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdData1x1S1P0 = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_data_wmma_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 
2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_wmma_i8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic 
instance + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 3532c3f4ba..3d243e3d56 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -101,8 +101,12 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp - device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp + 
device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp similarity index 74% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp index baa76a74af..79f01e77e9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_default_instance.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp" #ifdef CK_ENABLE_FP8 namespace ck { @@ -11,12 +11,12 @@ namespace instance { static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_default_instances( +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances{}); + instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp new file mode 100644 index 0000000000..6ca2790e04 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_default_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
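+
+// Each (pipeline, GemmSpecialization) pair gets its own translation unit,
+// which helps keep per-file compile time and object size manageable. This TU
+// registers the interwave v1 tuple with GemmSpecialization::Default; the
+// *_interwave_padded sibling registers the same tuple with MNKPadding.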
+ +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp" + +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp new file mode 100644 index 0000000000..34195f4720 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_padded_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_interwave_instance.hpp" + +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp similarity index 75% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp index f16809db28..a7d3e5febd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_padded_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_padded_instance.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
-#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v1_instance.hpp" #ifdef CK_ENABLE_FP8 namespace ck { @@ -11,12 +11,12 @@ namespace instance { static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_padded_instances( +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( std::vector>>& instances) { add_device_operation_instances( - instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances{}); + instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp new file mode 100644 index 0000000000..f7b720a610 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_default_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp" + +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp new file mode 100644 index 0000000000..d8dfb00f63 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_padded_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_v2_instance.hpp" + +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index aaa0d7e960..a4d23914dd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -4,17 +4,34 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp 
device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 45096f659f..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,217 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_generic_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 
2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - //PipelineVersion::v1 - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, 
PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, 
PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1>, - - //PipelineVersion::v1; interwave - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - 
DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - - - //PipelineVersion::v2 - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, 
PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, 
Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2> - // clang-format on - >; - -template -using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 
true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, - - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, - - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, - - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, 
F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, - - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_generic_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances{}); - - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmDefault, - ck::PipelineVersion::v1, - ck::LoopScheduler::Default>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmDefault, - ck::PipelineVersion::v2, - ck::LoopScheduler::Default>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmDefault, - ck::PipelineVersion::v1, - ck::LoopScheduler::Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmKPadding, - ck::PipelineVersion::v1, - ck::LoopScheduler::Default>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmKPadding, - ck::PipelineVersion::v2, - ck::LoopScheduler::Default>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< 
- GemmKPadding, - ck::PipelineVersion::v1, - ck::LoopScheduler::Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmMNKPadding, - ck::PipelineVersion::v1, - ck::LoopScheduler::Default>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmMNKPadding, - ck::PipelineVersion::v2, - ck::LoopScheduler::Default>{}); - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< - GemmMNKPadding, - ck::PipelineVersion::v1, - ck::LoopScheduler::Interwave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..7ee911e63b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| 
_NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1 + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, 
F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F16, 
F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1>
+    // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_generic_instances{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
new file mode 100644
index 0000000000..efc7a7ebfd
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp
@@ -0,0 +1,81 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
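The v1, v2, interwave, and irregular lists are all funneled through `add_device_gemm_xdl_splitk_*` entry points like the one above, which CK's instance factory aggregates for clients. Roughly, a client enumerates them as below; this is a sketch following the usual CK factory pattern, and the header path and alias names are assumptions rather than content of this patch:

    #include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"

    using Row         = ck::tensor_layout::gemm::RowMajor;
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Interface type matching the f16 mk_kn_mn (row-major A, B, C) instances above.
    using DeviceOp = ck::tensor_operation::device::
        DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

    int main()
    {
        // Collects every instance registered by the add_* functions in this directory.
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        return op_ptrs.empty() ? 1 : 0; // a real client profiles each op and keeps the fastest
    }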
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_iw_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1; interwave + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, 
F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_iw_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_iw_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_iw_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp new file mode 100644 index 0000000000..25dceab0fa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
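For reference, the only difference between the default-scheduler and interwave lists is the trailing LoopScheduler template argument; the tile configurations themselves repeat. With the aliases these instance files already declare in scope, the 256x256x128 entry from the two lists compares as follows (alias names here are illustrative, line-wrapped for readability):

    // PipelineVersion::v1 with the default intra-wave loop scheduler:
    using Gemm_256x256x128_Default = DeviceGemmXdlSplitKCShuffle<
        F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough,
        GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true,
        S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true,
        1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>;

    // Identical tile shape, but interleaving math and memory work across waves:
    using Gemm_256x256x128_Interwave = DeviceGemmXdlSplitKCShuffle<
        F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough,
        GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2,
        S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true,
        S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true,
        1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>;

Keeping the two scheduler variants in separate translation units lets the profiler pick whichever hides latency better for a given problem shape.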
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec, PipelineVersion PipVer, LoopScheduler LoopSche>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple<
+    // clang-format off
+    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+    //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>,
S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 
8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
+    // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
+                                       GemmDefault,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Interwave>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
+                                       GemmKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Interwave>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
+                                       GemmMNKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Interwave>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp
new file mode 100644
index 0000000000..6a323d323f
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
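+// Same irregular tile list as the interwave file above, registered here with
+// PipelineVersion::v1 and the default loop scheduler.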
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+template <GemmSpecialization GemmSpec, PipelineVersion PipVer, LoopScheduler LoopSche>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple<
+    // clang-format off
+    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+    //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64,
1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 
2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
+    // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
+                                       GemmDefault,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Default>{});
+
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
+                                       GemmKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Default>{});
+
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances<
+                                       GemmMNKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Default>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp
new file mode 100644
index 0000000000..c4f8f67145
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp
@@ -0,0 +1,79 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
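+// Regular tile shapes built against PipelineVersion::v2. As an illustrative sketch
+// (names are placeholders and the argument list is abridged; see the DeviceGemmSplitK
+// interface for the exact signature), a selected instance is driven like any CK device op:
+//   auto arg     = op.MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
+//                                         StrideA, StrideB, StrideC,
+//                                         PassThrough{}, PassThrough{}, PassThrough{},
+//                                         k_batch);
+//   auto invoker = op.MakeInvokerPointer();
+//   invoker->Run(arg.get(), StreamConfig{});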
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances = std::tuple<
+    // clang-format off
+    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+    //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    //PipelineVersion::v2
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8,
32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 
32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>
+    // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp
new file mode 100644
index 0000000000..52f40d346b
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec, PipelineVersion PipVer, LoopScheduler LoopSche>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple<
+    // clang-format off
+    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder|
SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>, + + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 
16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
+    // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
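+    // Register the irregular v2 tile list once per padding specialization declared above.
+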
add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmDefault, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances< + GemmMNKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index b22f4a3beb..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,202 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_generic_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16, PipelineVersion::v1> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - //PipelineVersion::v1 - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 
1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - - //PipelineVersion::v1; interwave - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, 
S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 
8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
-
-        //PipelineVersion::v2
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>
-        // clang-format on
-    >;
-template <GemmSpecialization GemmSpec, PipelineVersion PipVer, LoopScheduler LoopSche>
-using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
-        // clang-format off
-        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
-        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
-        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
-
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
-
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
-
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
-
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
-        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
-        // clang-format on
-    >;
-
-void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
-    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
-        instances)
-{
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_generic_instances{});
-
-    add_device_operation_instances(
-        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances<GemmDefault>{});
-
-    add_device_operation_instances(
-        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances<GemmMNPadding>{});
-
-    add_device_operation_instances(
-        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
-
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmDefault,
-                                       ck::PipelineVersion::v1,
-                                       ck::LoopScheduler::Default>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmDefault,
-                                       ck::PipelineVersion::v2,
-                                       ck::LoopScheduler::Default>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmDefault,
-                                       ck::PipelineVersion::v1,
-                                       ck::LoopScheduler::Interwave>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmKPadding,
-                                       ck::PipelineVersion::v1,
-                                       ck::LoopScheduler::Default>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmKPadding,
-                                       ck::PipelineVersion::v2,
-                                       ck::LoopScheduler::Default>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmKPadding,
-                                       ck::PipelineVersion::v1,
-                                       ck::LoopScheduler::Interwave>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmMNKPadding,
-                                       ck::PipelineVersion::v1,
-                                       ck::LoopScheduler::Default>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmMNKPadding,
-                                       ck::PipelineVersion::v2,
-                                       ck::LoopScheduler::Default>{});
-    add_device_operation_instances(instances,
-                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
-                                       GemmMNKPadding,
-                                       ck::PipelineVersion::v1,
-                                       ck::LoopScheduler::Interwave>{});
-}
-
-} // namespace instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp
new file mode 100644
index 0000000000..b8fc23a54b
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp
@@ -0,0 +1,90 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_generic_instances = std::tuple<
+        // clang-format off
+        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 2, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16, PipelineVersion::v1>
+        // clang-format on
+    >;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances = std::tuple<
+        // clang-format off
+        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        //PipelineVersion::v1
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>
+        // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_generic_instances{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
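Note: the registration function above is consumed through CK's instance factory. The following sketch is illustrative only and is not part of this patch; it assumes the gemm_splitk factory header and the DeviceGemmSplitK argument order used by the existing split-K client examples, and run_first_supported is a hypothetical name.

// Usage sketch (illustrative): enumerate the split-K GEMM instances registered
// above and run the first one that accepts the given problem description.
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" // assumed factory header

using DeviceOp = ck::tensor_operation::device::
    DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

inline void run_first_supported(const F16* p_a, const F16* p_b, F16* p_c,
                                ck::index_t M, ck::index_t N, ck::index_t K,
                                ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                                ck::index_t KBatch)
{
    // GetInstances() returns every instance added via add_device_operation_instances().
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op_ptr : op_ptrs)
    {
        auto arg_ptr = op_ptr->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                                   StrideA, StrideB, StrideC,
                                                   PassThrough{}, PassThrough{}, PassThrough{},
                                                   KBatch);
        if(op_ptr->IsSupportedArgument(arg_ptr.get()))
        {
            // First supported instance wins; a profiler would instead time all of them.
            op_ptr->MakeInvokerPointer()->Run(arg_ptr.get(), StreamConfig{});
            return;
        }
    }
}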
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp
new file mode 100644
index 0000000000..2855235f97
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp
@@ -0,0 +1,76 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_iw_instances = std::tuple<
+        // clang-format off
+        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        //PipelineVersion::v1; interwave
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>
+        // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_iw_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_iw_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_iw_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp
new file mode 100644
index 0000000000..b65c8c6a81
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec, PipelineVersion PipVer, LoopScheduler LoopSche>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
+        // clang-format off
+        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
+        // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmDefault,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Interwave>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Interwave>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmMNKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Interwave>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
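To make the long tuple rows above easier to audit, here is the first irregular instance rewritten with one template argument per line; the labels are read directly off the //###...| header-comment rows, add no new behavior, and the ExampleIrregularInstance alias is purely illustrative (not part of this patch):

// The first irregular row above, annotated per the column headers.
using ExampleIrregularInstance = DeviceGemmXdlSplitKCShuffle<
    F16, F16, F16, F32,                    // AData, BData, CData, AccData types
    Row, Col, Row,                         // ALayout, BLayout, CLayout
    PassThrough, PassThrough, PassThrough, // A/B/C elementwise operations
    GemmDefault,                           // GEMM specialization (GemmSpec)
    128,                                   // BlockSize
    16, 128,                               // MPerBlock, NPerBlock
    4, 8,                                  // K0PerBlock, K1
    16, 16,                                // MPerXDL, NPerXDL
    1, 4,                                  // MXdlPerWave, NXdlPerWave
    S<1, 4, 16, 1>,                        // ABlockTransfer ThreadCluster Lengths_K0_M_K1
    S<0, 2, 1, 3>, S<0, 2, 1, 3>,          // ABlockTransfer cluster arrange / src access order
    3, 8, 8, true,                         // A SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, AddExtraM
    S<1, 4, 32, 1>,                        // BBlockTransfer ThreadCluster Lengths_K0_N_K1
    S<0, 2, 1, 3>, S<0, 2, 1, 3>,          // BBlockTransfer cluster arrange / src access order
    3, 8, 8, true,                         // B SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, AddExtraN
    1, 1,                                  // CShuffle M/NXdlPerWavePerShuffle
    S<1, 16, 1, 8>,                        // CBlockTransferClusterLengths
    4,                                     // CBlockTransfer ScalarPerVector_NWaveNPerXdl
    F16,                                   // compute data type (as in the rows above)
    ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;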
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp
new file mode 100644
index 0000000000..5a5b8ce82d
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec, PipelineVersion PipVer, LoopScheduler LoopSche>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
+        // clang-format off
+        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
+        // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmDefault,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Default>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Default>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmMNKPadding,
+                                       ck::PipelineVersion::v1,
+                                       ck::LoopScheduler::Default>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp
new file mode 100644
index 0000000000..d487e8dd82
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp
@@ -0,0 +1,76 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances = std::tuple<
+        // clang-format off
+        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        //PipelineVersion::v2
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>
+        // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
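Each of these files feeds its tuple into add_device_operation_instances(). The real helper is declared in ck/library/tensor_operation_instance/add_device_operation_instance.hpp; the sketch below is illustrative only, assuming the helper simply appends one heap-allocated copy of every tuple element (the _sketch suffix marks it as hypothetical):

// Illustrative sketch of what a helper like add_device_operation_instances() does.
#include <memory>
#include <tuple>
#include <vector>

template <typename BaseOp, typename... Instances>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& instances,
                                           const std::tuple<Instances...>&)
{
    // Each DeviceGemmXdlSplitKCShuffle specialization is default-constructible, so
    // one object per tuple element is enough to expose it for dispatch/profiling.
    (instances.push_back(std::make_unique<Instances>()), ...);
}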
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp
new file mode 100644
index 0000000000..a721f4bc96
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp
@@ -0,0 +1,95 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec, PipelineVersion PipVer, LoopScheduler LoopSche>
+using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple<
+        // clang-format off
+        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 4, 8, 16, 16, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 512, 16, 4, 8, 16, 16, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 16, 4, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 8, 16, 16, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 8, 8, 16, 16, 1, 8, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 8, 8, 16, 16, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 16>, 4, F16, PipVer, LoopSche>,
+
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 8, 8, 16, 16, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 16, 8, 8, 16, 16, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 4, F16, PipVer, LoopSche>,
+        DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 8, 8, 16, 16, 4, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 4, F16, PipVer, LoopSche>
+        // clang-format on
+    >;
+
+void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmDefault,
+                                       ck::PipelineVersion::v2,
+                                       ck::LoopScheduler::Default>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmKPadding,
+                                       ck::PipelineVersion::v2,
+                                       ck::LoopScheduler::Default>{});
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances<
+                                       GemmMNKPadding,
+                                       ck::PipelineVersion::v2,
+                                       ck::LoopScheduler::Default>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
deleted file mode 100644
index 150ccf1a90..0000000000
--- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
+++ /dev/null
@@ -1,153 +0,0 @@
-// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
-
-#include <cstdlib>
-
-#include "ck/ck.hpp"
-#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
-
-#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
-
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace instance {
-
-using F8 = ck::f8_t;
-using F16 = ck::half_t;
-using F32 = float;
-
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
-static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-
-using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances = std::tuple<
-        // clang-format off
-        //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-        //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
-        //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
-        //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - //PipelineVersion::v1 - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, 
S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 
4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1>, - - //PipelineVersion::v1; interwave - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, 
Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - - //PipelineVersion::v2 - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2> - // clang-format on - >; - -template -using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 
8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave> - // clang-format on - >; - -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, - device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp new file mode 100644 index 0000000000..b66eef2834 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8 = ck::f8_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances = std::tuple<
+ // clang-format off
+ //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+ //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+ //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>
+ // clang-format on
+ >;
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp
new file mode 100644
index 0000000000..0900ba02ce
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8 = ck::f8_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances = std::tuple<
+ // clang-format off
+ //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+ //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+ //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1 + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, 
PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, 
PipelineVersion::v1>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1>
+ // clang-format on
+ >;
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(instances,
+                                   device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp
new file mode 100644
index 0000000000..dec25ef93b
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp
@@ -0,0 +1,82 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1; interwave + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< 
F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 
GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+ DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>
+ // clang-format on
+ >;
+
+void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v1_interwave_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp
new file mode 100644
index 0000000000..d7f6433012
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp
@@ -0,0 +1,80 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8 = ck::f8_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances = std::tuple<
+    // clang-format off
+    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+    //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    //PipelineVersion::v2
+        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>,
+        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256,
64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 
8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_v2_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp similarity index 50% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp index c1e43937a6..302e2bc250 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp @@ -63,41 +63,11 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - - //PipelineVersion::v1; interwave - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 
8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, 
LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - - //PipelineVersion::v2 - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 
8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2> + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1> // clang-format on >; -void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances( std::vector>>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp new file mode 100644 index 0000000000..48dd5dbada --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
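Before the next file, a reminder of what every instance in this patch computes: split-K GEMM divides the K dimension into k_batch slices, each slice produces a partial C tile, and the partials are summed (on device this is an atomic accumulation or a second reduction pass, depending on the instance). A plain host-side reference, with all names local to this sketch:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Reference split-K GEMM: C[m][n] accumulates one partial product per K-slice.
    // Illustrative only; the device instances below fuse the reduction on the GPU.
    void gemm_splitk_reference(const std::vector<float>& A, // M x K, row-major
                               const std::vector<float>& B, // K x N, row-major
                               std::vector<float>& C,       // M x N, row-major
                               std::size_t M, std::size_t N, std::size_t K,
                               std::size_t k_batch)
    {
        const std::size_t k_per_batch = (K + k_batch - 1) / k_batch;
        C.assign(M * N, 0.0f);
        for(std::size_t kb = 0; kb < k_batch; ++kb) // batches may run concurrently on device
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t n = 0; n < N; ++n)
                {
                    const std::size_t k0 = kb * k_per_batch;
                    const std::size_t k1 = std::min(K, k0 + k_per_batch);
                    float partial = 0.0f;
                    for(std::size_t k = k0; k < k1; ++k)
                        partial += A[m * K + k] * B[k * N + n];
                    C[m * N + n] += partial; // on device: atomic add or second pass
                }
    }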
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8 = ck::f8_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple<
+    // clang-format off
+    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+    //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    //PipelineVersion::v1; interwave
+        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+        DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>,
+
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave> + // clang-format on + >; + +void 
add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(
+    std::vector<std::unique_ptr<DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances)
+{
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances<GemmDefault>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances<GemmMNPadding>{});
+
+    add_device_operation_instances(
+        instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances<GemmMNKPadding>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp
new file mode 100644
index 0000000000..b8029bc075
--- /dev/null
+++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp
@@ -0,0 +1,77 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp"
+
+#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+using F8 = ck::f8_t;
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+
+// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
+template <GemmSpecialization GemmSpec>
+using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple<
+    // clang-format off
+    //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+    //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
+    //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
+    //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+    //PipelineVersion::v2
+
DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 8, F16, 
PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 49e904e990..0000000000 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,135 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F8 = ck::f8_t; -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_generic_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| 
DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 2> - // clang-format on - >; - -// Compilation parameters for a[m, k] * b[k, n] = c[m, n] -template -using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| - //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| - //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - //PipelineVersion::v1 - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 
4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, 
S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1>, - - //PipelineVersion::v1; interwave - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, 
F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 
32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave>, - - - - //PipelineVersion::v2 - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, 
Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, - DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2> - // clang-format on - >; -void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_generic_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); - - add_device_operation_instances( - instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..c67548cbef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
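Each instance tuple below is a template over GemmSpecialization and is registered three times, with GemmDefault, GemmMNPadding and GemmMNKPadding, so problems whose M, N or K is not a multiple of the tile sizes still have a covering kernel. A hypothetical helper showing the selection rule; CK itself encodes the equivalent check per instance inside IsSupportedArgument:

    #include <cstdint>

    enum class GemmSpecialization { Default, MNPadding, MNKPadding };

    // Hypothetical helper: pick the least-padded specialization that covers the problem.
    GemmSpecialization choose_spec(std::int64_t M, std::int64_t N, std::int64_t K,
                                   std::int64_t MPerBlock, std::int64_t NPerBlock,
                                   std::int64_t KPerBlock)
    {
        const bool m_ok = M % MPerBlock == 0;
        const bool n_ok = N % NPerBlock == 0;
        const bool k_ok = K % KPerBlock == 0;

        if(m_ok && n_ok && k_ok)
            return GemmSpecialization::Default;    // no padding, fastest path
        if(k_ok)
            return GemmSpecialization::MNPadding;  // pad only the output tiles
        return GemmSpecialization::MNKPadding;     // pad everything, always valid
    }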
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 2> + // clang-format on + >; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| 
CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1 + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, 
Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1> + // clang-format on + >; +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_generic_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, 
device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp new file mode 100644 index 0000000000..204fb1e386 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v1; interwave + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 
1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 
8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 32, 4, 8, 32, 32, 1, 1, S<1, 2, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v1, LoopScheduler::Interwave>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 32, 4, 8, 16, 16, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 4, F16, PipelineVersion::v1, LoopScheduler::Interwave> + // clang-format on + >; +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v1_interwave_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp new file mode 100644 index 0000000000..533f6cbd94 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, 
Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //PipelineVersion::v2 + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, 
PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, 
PassThrough, PassThrough, GemmSpec, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, PipelineVersion::v2>, + DeviceGemmXdlSplitKCShuffle< F8, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16, PipelineVersion::v2> + // clang-format on + >; +void add_device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_v2_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); + + add_device_operation_instances( + instances, device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp index ac2ba91b63..3afba67be8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 39af70f623..6f45474526 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
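// The remaining rename hunks are mechanical: each grouped convolution backward-data
// WMMA instance file stops including the monolithic
// device_grouped_conv_bwd_data_wmma_instance.hpp and instead includes the per-datatype
// split header matching what it instantiates (_f16_ for the half-precision files,
// _i8_ for the int8 files), across the 2D and 3D layouts and their 1x1s1p0 variants.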
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp index 1de9f7a95c..1e60075d8a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp index 8eb6558a2a..b2b7dc7eae 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp index f46cdf7f18..4efc65c215 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 2f4659524f..ffb64e6d73 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp index 789b80b8ac..70623a1222 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp index 71394762b4..aedcd698f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp index ba2f6bdc0a..599fa38305 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 26403bf0e5..8438a94e3f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp index 9e453a36c5..b3440f027c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp index 6f7d4f79d1..9693ba9a56 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp index 4475a0e456..a10e5a7128 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index f3941636f8..ce982aa65c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_f16_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp index b8479c9c9d..7b53e004f7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp index d8a74e744b..f293164c8a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_i8_instance.hpp" namespace ck { namespace tensor_operation { From 1f306024d01ed4ebf66f226c882fdcaa7ae207a7 Mon Sep 17 00:00:00 2001 From: Lakhinder Walia <139581206+lakhinderwalia@users.noreply.github.com> Date: Wed, 7 Feb 2024 19:24:51 -0800 Subject: [PATCH 75/75] fast_gelu: minor code reorg to enhance ref & gpu performance (#1162) --- .../element/unary_element_wise_operation.hpp | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index db89a79723..70c72bf768 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -458,27 +458,29 @@ struct FastGelu template <> __host__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - - y = x * cdf; + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); } // device code, use lower precision "__expf" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = __expf(-u); + // const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = __expf(u); #if !CK_WORKAROUND_SWDEV_383542 - const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); + y = x * __frcp_rn(1.f + emu); #else - const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); + y = x * __ocml_native_recip_f32(1.f + emu); #endif - - y = x * cdf; } template <>