diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml new file mode 100644 index 0000000000..8c52856755 --- /dev/null +++ b/.azuredevops/rocm-ci.yml @@ -0,0 +1,42 @@ +resources: + repositories: + - repository: pipelines_repo + type: github + endpoint: ROCm + name: ROCm/ROCm + +variables: +- group: common +- template: /.azuredevops/variables-global.yml@pipelines_repo + +trigger: + batch: true + branches: + include: + - develop + paths: + exclude: + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + +pr: + autoCancel: true + branches: + include: + - develop + paths: + exclude: + - .github + - docs + - '.*.y*ml' + - '*.md' + - Jenkinsfile + - LICENSE + drafts: false + +jobs: + - template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo diff --git a/CMakeLists.txt b/CMakeLists.txt index c23746e7f3..3f9e445837 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,7 @@ endif() set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version} LANGUAGES CXX) +project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) @@ -112,7 +112,7 @@ message("checking which targets are supported") #Setting GPU_TARGETS on command line will override this list if(NOT PROFILER_ONLY) rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) @@ -135,12 +135,10 @@ endif() message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") -set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) - if(GPU_TARGETS) message("Building CK for the following targets: ${GPU_TARGETS}") else() - message("Building CK for the following targets: ${AMDGPU_TARGETS}") + message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}") endif() if (GPU_TARGETS) @@ -225,7 +223,13 @@ link_libraries(Threads::Threads) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") + +## HIP +set(CMAKE_HIP_PLATFORM amd) +set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) +set(CMAKE_HIP_EXTENSIONS ON) +message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") ## OpenMP if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 710eca9f49..e8c046ff44 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) -endif() \ No newline at end of file + + if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) + endif() + + if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) + target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) + endif() +endif() diff --git a/client_example/07_grouped_convnd_fwd/common.hpp b/client_example/07_grouped_convnd_fwd/common.hpp new file mode 100644 index 0000000000..729af0b88b --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/common.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) +{ + // 2 * G * N * K * C * * + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return static_cast(2) * G * N * K * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t flop = GetFlops(out_lengths, wei_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough, + AComputeType, + BComputeType>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 4983ac33c3..d3a3111e94 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - std::array in_lengths{G, N, Wi, C}; - std::array in_strides{0, 0, 0, 1}; - - std::array wei_lengths{G, K, X, C}; - std::array wei_strides{0, 0, 0, 1}; - - std::array out_lengths{G, N, Wo, K}; - std::array out_strides{0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW - std::rotate(rbegin(in_lengths), - std::next(rbegin(in_lengths)), - std::next(rbegin(in_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(in_strides), - std::next(rbegin(in_strides)), - std::next(rbegin(in_strides), NumDimSpatial + 1)); - std::rotate(rbegin(wei_lengths), - std::next(rbegin(wei_lengths)), - std::next(rbegin(wei_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(wei_strides), - std::next(rbegin(wei_strides)), - std::next(rbegin(wei_strides), NumDimSpatial + 1)); - std::rotate(rbegin(out_lengths), - std::next(rbegin(out_lengths)), - std::next(rbegin(out_lengths), NumDimSpatial + 1)); - std::rotate(rbegin(out_strides), - std::next(rbegin(out_strides)), - std::next(rbegin(out_strides), NumDimSpatial + 1)); - - std::array filter_strides{1}; - std::array filter_dilations{1}; - std::array input_left_pads{1}; - std::array input_right_pads{1}; - - SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Wo * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C + - sizeof(WeiDataType) * G * K * X * C + - sizeof(OutDataType) * G * N * Wo * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index 9383350629..fb8a410ab3 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,17 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include -#include -#include +#include "common.hpp" #include "ck/ck.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W static constexpr ck::index_t Ho = 28; // output H static constexpr ck::index_t Wo = 28; // output W -struct SimpleDeviceMem -{ - SimpleDeviceMem() = delete; - - SimpleDeviceMem(std::size_t mem_size) : p_mem_{} - { - (void)hipMalloc(static_cast(&p_mem_), mem_size); - } - - void* GetDeviceBuffer() { return p_mem_; } - - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } - - void* p_mem_; -}; - int main() { - // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space - // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW - // Hence, we need to adjust the order of stride - std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; - std::array wei_lengths{G, K, C, Y, X}; - std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; - std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; - - std::array filter_strides{1, 1}; - std::array filter_dilations{1, 1}; - std::array input_left_pads{1, 1}; - std::array input_right_pads{1, 1}; - - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_op_name; - int best_op_id = -1; - float best_avg_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - // profile device operation instances - std::cout << "Run all instances and do timing" << std::endl; - - for(int i = 0; i < op_ptrs.size(); ++i) - { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::string op_name = op_ptr->GetTypeString(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - - std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + - sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * N * Ho * Wo * G * K; - - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_bytes / 1.E6 / avg_time; - - std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_op_id = i; - best_op_name = op_name; - best_avg_time = avg_time; - best_gb_per_sec = gb_per_sec; - best_tflops = tflops; - } - } - else - { - std::cerr << op_name << " does not support this problem" << std::endl; - } - } - - if(best_op_id < 0) - { - std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; - } - - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance - { - auto& op_ptr = op_ptrs[best_op_id]; - std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() - << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - wei_lengths, - wei_strides, - {}, - {}, - out_lengths, - out_strides, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - } - - std::cout << "Done" << std::endl; - } + return run_grouped_conv_fwd({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp similarity index 100% rename from client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp rename to client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index 23311b4024..5279e3dfcf 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -7,22 +7,6 @@ endif() if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp) - target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp) - target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) -endif() - -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) - add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) - target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) - - add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp) - target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index dc1824931e..9473b4aba0 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.1.1 +rocm-docs-core==1.1.3 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 9a451d9708..8941877aea 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==1.1.1 +rocm-docs-core==1.1.3 # via -r requirements.in six==1.16.0 # via diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 5465adb779..fd9f5cd89d 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(EX_TARGETS ${GPU_TARGETS}) + endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #only continue if there are some source files left on the list if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(check ${EXAMPLE_NAME}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) @@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(EX_TARGETS ${GPU_TARGETS}) + endif() #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") @@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") message("removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") message("removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #only continue if there are some source files left on the list if(FILE_NAME) + if(FILE_NAME MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(FILE_NAME MATCHES "_wmma") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_dependencies(examples ${EXAMPLE_NAME}) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index a3248e2a5e..0bb5408772 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -34,6 +34,7 @@ args: if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) -s_k seqlen_k, -1 means equal to s (default:-1) -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b1f7d5b1d1..5f887f0655 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -44,11 +44,18 @@ auto create_args(int argc, char* argv[]) "-1", "num of head, for k/v, -1 means equal to h\n" "if not equal to h, then this is GQA/MQA case") - .insert("s", - "3328", - "seqlen_q. if group-mode, means the average value of seqlen_q\n" - "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") + .insert( + "s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") .insert("s_k", "-1", "seqlen_k, -1 means equal to s") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed. same as xformer kv_padding") .insert("d", "128", "head dim for q, k") .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale_s", @@ -106,6 +113,7 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); @@ -165,10 +173,20 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - ck_tile::index_t seqlen_q = arg_parser.get_int("s"); - ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k < 0) - seqlen_k = seqlen_q; + auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, + batch, + arg_parser.get_str("s"), + arg_parser.get_str("s_k"), + arg_parser.get_str("s_kpad")); + +#if 0 + // clang-format off + std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl; + // clang-format on +#endif + ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); if(hdim_v < 0) @@ -217,7 +235,8 @@ bool run(const ck_tile::ArgParser& arg_parser) bool lse = arg_parser.get_bool("lse"); bias_info bias = bias_info::decode(arg_parser.get_str("bias")); - mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + mask_info mask = mask_info::decode( + arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); @@ -245,11 +264,16 @@ bool run(const ck_tile::ArgParser& arg_parser) int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); - ck_tile::stream_config stream_config{ - nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; - const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); - const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); using TypeConfig = FmhaFwdTypeConfig; @@ -312,9 +336,11 @@ bool run(const ck_tile::ArgParser& arg_parser) // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); const ck_tile::index_t shape_seqlen_q = - (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); const ck_tile::index_t shape_seqlen_k = - (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + (mode == mode_enum::batch ? seqlen_ks[0] + : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() + : seqstart_k_with_padding_host.back())); ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); @@ -421,6 +447,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); @@ -429,7 +456,9 @@ bool run(const ck_tile::ArgParser& arg_parser) v_buf.ToDevice(v_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); + seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data()); alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off @@ -445,7 +474,9 @@ bool run(const ck_tile::ArgParser& arg_parser) const std::string prec = arg_parser.get_str("prec"); std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch - << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? "" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout << std::flush; @@ -476,7 +507,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return ck_tile::identity{}; }(); - auto fmha_args = [&]() { + auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { assert(nhead % nhead_k == 0); /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// seqlen_k] in this example, hence both the 'batch_stride_bias' & @@ -526,7 +557,7 @@ bool run(const ck_tile::ArgParser& arg_parser) o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), - nullptr, + k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), shape_seqlen_q, shape_seqlen_k, batch, @@ -607,7 +638,10 @@ bool run(const ck_tile::ArgParser& arg_parser) // adjust matrix index according to the mode const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); - const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); const auto v_host_ref_lengths = std::array{nhead, hdim_v, real_seqlen_k}; diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 3cd4d2759e..a265a23c8d 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -78,6 +78,11 @@ BOOL_MAP = { "f" : "false" } +TILE_PARTITIONER_MAP = { + "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", + "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +} + GEN_DIR = "" # in Cmake, have to generate files in same folder FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT @@ -138,7 +143,7 @@ using fmha_epilogue_{F_idx} = {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel, + ck_tile::FmhaFwdKernel<{F_tile_partitioner}, fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>; @@ -156,7 +161,7 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ @@ -394,6 +399,12 @@ class FmhaFwdKernel: F_pipeline : FmhaFwdPipeline mask_impl : str + def get_tp(self) -> str: + if self.F_mode == 'group': + return 'hbs' + else: + return 'shb' + @property def template(self) -> str: kernel_body = str() @@ -418,7 +429,7 @@ class FmhaFwdKernel: F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], @@ -427,12 +438,13 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ self.F_tile.name + '_' + self.F_pipeline.name @property diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 049a7f4b7f..63813e079c 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -29,6 +29,7 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS done done diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index e10ae617dc..737efd8256 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -4,12 +4,14 @@ #pragma once #include +#include #include #include #include #include #include #include +#include #include "ck_tile/core/container/span.hpp" @@ -37,12 +39,14 @@ std::vector to_seqstarts(ck_tile::span seqlens) std::vector generate_seqlens(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, // if not negative, clamp max std::optional seed = std::nullopt) { assert(0 < count); - std::vector seqlens(count, seqlens_sum); + std::vector seqlens( + count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg); if(mode == mode_enum::group && 1 < count) { @@ -55,7 +59,7 @@ std::vector generate_seqlens(mode_enum mode, std::uniform_int_distribution step_dist(1, count - 1); auto next_step = std::bind(step_dist, std::ref(random_engine)); - for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) { const size_type to_decrease = next_idx(); // make sure each elements of seqlens is always greater than 0 @@ -66,6 +70,11 @@ std::vector generate_seqlens(mode_enum mode, const size_type to_increase = (to_decrease + next_step()) % count; + if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max) + { + continue; + } + --seqlens[to_decrease]; ++seqlens[to_increase]; } @@ -76,10 +85,91 @@ std::vector generate_seqlens(mode_enum mode, std::vector generate_seqstarts(mode_enum mode, unsigned count, - int32_t seqlens_sum, + int32_t seqlen_avg, + int32_t seqlen_max = -1, std::optional seed = std::nullopt) { - return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); + return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed)); +} + +/* + * decode the seqlen string from cmdline + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +std::tuple, + std::vector, + std::vector> +decode_seqlen(mode_enum mode, + ck_tile::index_t batch, + std::string q_val, + std::string k_val, + std::string k_pad_val, + std::optional seed = std::nullopt) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + if(mode == mode_enum::batch) + { + ck_tile::index_t q = _S2I_(q_val); + ck_tile::index_t k = _S2I_(k_val); + auto s_q = std::vector(batch, q); + auto s_k = std::vector(batch, k < 0 ? q : k); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + return std::make_tuple(s_q, s_k, s_kpad); + } + else + { + ck_tile::index_t idx = 0; + std::string::size_type pos_q = 0; + std::string::size_type pos_k = 0; + std::string::size_type pos_kp = 0; + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + while(true) + { + auto found_q = q_val.find(',', pos_q); + auto found_k = k_val.find(',', pos_k); + auto found_kp = k_pad_val.find(',', pos_kp); + + ck_tile::index_t q = _S2I_( + q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q)); + ck_tile::index_t k = _S2I_( + k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k)); + ck_tile::index_t kp = _S2I_(k_pad_val.substr( + pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + idx++; + if(found_q == std::string::npos || idx >= batch) + { + break; + } + pos_q = found_q + 1; + pos_k = found_k == std::string::npos ? pos_k : found_k + 1; + pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1; + } + if(idx < batch) + { + auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed); + auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + } + return std::make_tuple(s_q, s_k, s_kpad); + } +#undef _S2I_ } int env_get_int(const char* var_name, int default_int) @@ -87,6 +177,6 @@ int env_get_int(const char* var_name, int default_int) char* v = getenv(var_name); int r = default_int; if(v) - r = atoi(v); + r = std::atoi(v); return r; } diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 36993d0ae2..9d9974d49d 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -104,20 +104,25 @@ inline void flush_icache() hip_check_error(hipGetLastError()); } // if TimePrePress == false, return time does not include preprocess's time -template +template float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, - Args& args) + GemmArgs& gemm_args, + Args... args) { #if CK_TIME_KERNEL #define MEDIAN 1 if(stream_config.time_kernel_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, @@ -133,7 +138,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, // warm up for(int i = 0; i < stream_config.cold_niters_; ++i) { - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); } @@ -142,7 +147,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { return 0.0; } - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } @@ -172,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, preprocess(); } // run real kernel - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); // end real kernel @@ -186,13 +191,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, total_time += cur_time; #endif - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; - printf("args.p_a_grid: %p, args.p_b_grid:%p\n", - static_cast(args.p_a_grid), - static_cast(args.p_b_grid)); + printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", + static_cast(gemm_args.p_a_grid), + static_cast(gemm_args.p_b_grid)); } } @@ -216,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, else { preprocess(); - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); return 0; } #else - kernel<<>>(args); + kernel<<>>(gemm_args, args...); hip_check_error(hipGetLastError()); return 0; diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 1cdb7f9c5a..a616433ac9 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -20,7 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, @@ -41,7 +41,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } @@ -95,7 +95,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", __func__, @@ -117,7 +117,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, } const int nrepeat = stream_config.nrepeat_; - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { printf("Start running %d times...\n", nrepeat); } diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index f68473c292..c152cbfb1e 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1952,7 +1952,7 @@ struct Modulo } }; -template +template struct Xor { using LowerIndex = MultiIndex<2>; @@ -1981,8 +1981,15 @@ struct Xor idx_low(Number<0>{}) = idx_up[Number<0>{}]; - idx_low(Number<1>{}) = - idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + if constexpr(ApplyModulo) + { + idx_low(Number<1>{}) = + idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + } + else + { + idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}]; + } } template {modulus, up_length}; } +template +__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} + template __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths) { - return Xor{low_lengths}; + return Xor{low_lengths}; } } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 701dd04f6c..e5e6245cb8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -795,11 +795,6 @@ struct BlockwiseGemmXdlops_v2 "wrong!"); } - __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other) - : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin) - { - } - // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 4521b2161f..6ab1669e30 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -587,7 +587,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle BatchStrideD1s, BatchStrideE1} { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 37ebe2f85c..34b1d503af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -658,7 +658,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { { std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 445467be55..e178b8f525 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -719,7 +719,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { arg.Print(); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 4cc60f2836..9d5b74be6c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -53,8 +53,7 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 6fd8c03232..0b73317c5e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -516,7 +516,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { { std::cout << "arg.a_grid_desc_k0_m_k1_container_{" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index f5c1460f56..13eb23574f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -644,7 +644,7 @@ struct float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 9015f640ad..28778d825b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -614,7 +614,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index e815c0784d..7fa231d4f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -579,7 +579,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << DeviceOp{}.GetTypeString() << std::endl; std::cout << "N " << arg.Conv_N_ << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 760e2840d4..3be7313d2b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -431,7 +431,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index de48719398..6e69213513 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -401,7 +401,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 149aca7e3e..b84e181306 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1272,7 +1272,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1_container_{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 4398724553..de8f35a640 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1220,7 +1220,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "arg.a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index d3af5e63d3..b1784b3858 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -334,7 +334,7 @@ struct DeviceGemmDl : public DeviceGemm #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" namespace ck { namespace tensor_operation { namespace device { template + index_t NumBatchToMerge, + bool HasMainKBlockLoop, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + index_t MinimumOccupancy = 1, + TailNumber TailNum = TailNumber::Full> __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, - p_shared, - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = batch_count; - ignore = block_2_ctile_map; - ignore = compute_ptr_offset_of_batch; + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} - compute_ptr_offset_of_batch.GetAPtrOffset(0); - compute_ptr_offset_of_batch.GetBPtrOffset(0); - compute_ptr_offset_of_batch.GetCPtrOffset(0); +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -121,7 +163,7 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, + index_t NumBatchToMerge = 1, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle : public DeviceGroupedConvBwdWeight { + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; using ADataType = OutDataType; @@ -183,52 +232,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle static constexpr auto K1Number = Number{}; - static constexpr auto conv_to_gemm_transformer = + static constexpr auto conv_to_gemm_transformer_v2 = + TransformConvBwdWeightToGemmV2{}; + + static constexpr auto conv_to_gemm_transformer_v1 = TransformConvBwdWeightToGemm{}; - // Bytes per 32 lds bank: 32 * 4 bytes - static constexpr auto BankLength = 128; - static constexpr auto ElePerBank = BankLength / sizeof(ADataType); - - // M1 & M0 - static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; - static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; - static constexpr auto ABlockLdsM1Padding = 4; - - // N1 & N0 - static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; - static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; - static constexpr auto BBlockLdsN1Padding = 4; - - template ::type = false> - static auto GetABCGridDesc() - { - const ck::index_t dim = 1; - const ck::index_t batch = 1; - const std::array lengths{1}; - const std::array strides{1, 1, 1, 1}; - const std::array params{1}; - return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); - } + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; template ::type = false> static auto GetABCGridDesc() @@ -238,21 +259,21 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const std::array lengths{1, 1}; const std::array strides{1, 1, 1, 1, 1}; const std::array params{1, 1}; - return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> @@ -263,21 +284,71 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const std::array lengths{1, 1, 1}; const std::array strides{1, 1, 1, 1, 1, 1}; const std::array params{1, 1, 1}; - return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( - dim, - dim, - dim, - lengths, - lengths, - lengths, - strides, - strides, - strides, - params, - params, - params, - params, - batch); + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; } using ABCGridDescs = decltype(GetABCGridDesc()); @@ -285,60 +356,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using AGridDesc_K0_M_K1 = remove_cvref_t; using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + using CElementwiseGridDesc_M_N = + remove_cvref_t())>; - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< - BlockSize, - ADataType, - BDataType, - AccDataType, - AccDataType, - InMemoryDataOperationEnum::AtomicAdd, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - element_wise::PassThrough, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXdl, - NPerXdl, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - ABlockLdsM1PerBlock, - ABlockLdsM0PerBlock, - ABlockLdsM1Padding, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - BBlockLdsN1PerBlock, - BBlockLdsN0PerBlock, - BBlockLdsN1Padding, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXdl, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - true, - true, - 1, - PipelineVersion::v1, - ComputeTypeA, - ComputeTypeB>; + using GridwiseGemm = + GridwiseGemm_xdl_cshuffle_v3; static constexpr index_t ClusterLengthMPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); @@ -347,8 +414,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; using GridwiseElementwise = - GridwiseElementwise, - Tuple, + GridwiseElementwise, + Tuple, Tuple, Tuple, Block2TileMapElementwise, @@ -366,10 +433,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); - - using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); struct Argument : public BaseArgument { @@ -395,11 +460,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle : p_a_grid_{p_out_grid}, p_b_grid_{p_in_grid}, p_e_grid_{p_wei_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, ce_grid_desc_m_n_{}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{}, compute_ptr_offset_of_batch_{}, M01_{M01}, N01_{N01}, @@ -430,7 +494,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle begin(output_spatial_lengths_)); const auto descs = - conv_to_gemm_transformer + conv_to_gemm_transformer_v2 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( Conv_N_, Conv_K_, @@ -447,15 +511,34 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle input_right_pads, k_batch_); - a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; - b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; - ce_grid_desc_m_n_ = descs[I2]; + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ce_elementwise_grid_desc_m_n_ = + conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_)[I2]; - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); + // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; @@ -465,16 +548,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); - - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - ce_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( - ce_grid_desc_m_n_); - } + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); } std::size_t GetWorkspaceSizeBytes() const @@ -486,12 +564,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle const BDataType* p_b_grid_; EDataType* p_e_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N ce_grid_desc_m_n_; + CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; - Block2CTileMap block_2_ctile_map_; Block2TileMapElementwise elementwise_block_2_ctile_map_; // for computing batch offset @@ -525,96 +603,676 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle void ShowInfo(const Argument& arg) { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.ce_grid_desc_m_n_, - arg.block_2_ctile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); - } + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); - const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + AccDataType* p_c_grid = type_convert(arg.p_workspace_); - auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { - AccDataType* p_c_grid = type_convert(arg.p_workspace_); - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; + // nullptr for output, will be set after workspace set + typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_, + arg.p_b_grid_, + p_c_grid, + GemmM, + GemmN, + GemmK, + I0, + I0, + I0, + arg.k_batch_}; - constexpr bool has_main_loop = has_main_k_block_loop.value; + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge); - auto preprocess = [&]() { - hip_check_error(hipMemsetAsync( - p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); - }; + float ave_time = 0; - const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< - GridwiseGemm, - ADataType, - BDataType, - AccDataType, - OutElementwiseOperation, - InElementwiseOperation, - element_wise::PassThrough, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock); - return launch_and_time_kernel_with_preprocess( - stream_config, - preprocess, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - p_c_grid, - arg.a_element_op_, - arg.b_element_op_, - element_wise::PassThrough{}, - arg.Conv_G_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_, - arg.compute_ptr_offset_of_batch_); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync( + gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); }; + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + clear_workspace(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time = launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumBatchToMerge, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { auto launch_elementwise_kernel = [&]() { const AccDataType* p_c_grid = type_convert(arg.p_workspace_); - const index_t grid_size = - arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * - arg.Conv_G_; + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.ce_elementwise_grid_desc_m_n_) * + arg.Conv_G_; std::array in_out_batch_strides = { arg.compute_ptr_offset_of_batch_.BatchStrideC_}; const auto kernel = kernel_batched_elementwise, - ck::Tuple, + ck::Tuple, + ck::Tuple, ck::Tuple, ck::Tuple, Block2TileMapElementwise, @@ -627,8 +1285,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle dim3(grid_size), dim3(BlockSize), 0, - make_tuple(arg.ce_grid_desc_m_n_), - make_tuple(arg.ce_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), make_tuple(p_c_grid), make_tuple(arg.p_e_grid_), arg.elementwise_block_2_ctile_map_, @@ -638,16 +1296,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle in_out_batch_strides); }; - float avg_time = 0; - if(has_main_k0_block_loop) - { - avg_time = launch_gemm_kernel(integral_constant{}); - } - else - { - avg_time = launch_gemm_kernel(integral_constant{}); - } - + float avg_time = RunGemmV3(arg, stream_config); avg_time += launch_elementwise_kernel(); return avg_time; } @@ -667,6 +1316,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + // Check this here, it allows to use other instances from factory even // if workspace is not allocated if(!arg.p_workspace_) @@ -723,10 +1389,38 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle } } + if constexpr(NumBatchToMerge > 1) + { + // support only if whole M and N can be proccessed on one block + if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1)) + { + return false; + } + if(arg.Conv_G_ % NumBatchToMerge != 0) + { + return false; + } + } + + if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0)) + { + if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1)) + { + return false; + } + } + // vector load A/B matrix from global memory - if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && - arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && - arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1)) { return false; } @@ -737,11 +1431,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle return false; } - // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.ce_grid_desc_m_n_, - arg.block_2_ctile_map_); + return true; } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -840,13 +1530,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle { auto str = std::stringstream(); + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + // clang-format off str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock << ", " + << KPerBlock << ", " << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " << K1 << ", " << MXdlPerWave << ", " @@ -857,7 +1558,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle << BBlockTransferDstScalarPerVector_K1 << ", " << CShuffleMXdlPerWavePerShuffle << ", " << CShuffleNXdlPerWavePerShuffle << ", " - << CBlockTransferScalarPerVector_NWaveNPerXdl + << CBlockTransferScalarPerVector_NWaveNPerXdl << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << NumBatchToMerge << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index bf8788a3b2..1f60818e39 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -45,8 +45,7 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t KBatch = 1; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index a88c7b4fb7..ac392cddc4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -553,7 +553,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "The group count is not equal to sum of skipped groups " "and kernel args size!" @@ -835,7 +836,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); if(not group_arg_valid) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[" << __func__ << "] group id: " << i << " has invalid GridwiseGemm settings!" << std::endl; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 403bc7fad6..36cbd1cd26 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -620,7 +620,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop GridwiseGemm::template CheckTensorTransfersValidity( M, N, K))) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << "," << K << "] are not supported by current template parameters!" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 90c0593b28..658f323516 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -514,7 +514,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "The group count is not equal to sum of skipped groups " "and kernel args size!" @@ -545,7 +545,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK(p_a_grid, @@ -80,7 +79,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx11__)) } // Assume B is Col-Major diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp new file mode 100644 index 0000000000..d2a06ba9af --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp @@ -0,0 +1,1369 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index dea92bc576..50e6f68e61 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,8 +34,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -48,7 +47,7 @@ __global__ void karg); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template {}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -671,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -742,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -805,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)), @@ -935,7 +933,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.M % MPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -952,7 +950,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.N % NPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -971,7 +969,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto K_t = karg.KBatch * KPerBlock; if(!(karg.K % K_t == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " << karg.K << " " << __FILE__ << ":" << __LINE__ @@ -995,7 +993,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -1009,7 +1007,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -1024,7 +1022,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -1038,7 +1036,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -1053,7 +1051,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of " @@ -1069,7 +1067,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of " @@ -1084,7 +1082,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 if constexpr(is_same, bhalf_t>::value) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index fdafa9ca5c..f9071bd29d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,8 +33,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -49,7 +48,7 @@ __global__ void karg.c_element_op); #else ignore = karg; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template {}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -849,7 +847,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(AK1Number)), @@ -920,8 +918,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(Number{}, - Number{})), + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), make_pass_through_transform(BK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); @@ -983,7 +981,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 make_tuple( make_pass_through_transform(Number{}), make_pass_through_transform(Number{}), - make_xor_transform( + make_xor_with_modulo_transform( make_tuple(Number{}, Number{})), make_pass_through_transform(Number{}), make_pass_through_transform(BK1Number)), @@ -1113,7 +1111,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.M % MPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -1130,7 +1128,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(!(karg.N % NPerBlock == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ @@ -1149,7 +1147,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 auto K_t = karg.KBatch * KPerBlock; if(!(karg.K % K_t == 0)) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " << karg.K << " " << __FILE__ << ":" << __LINE__ @@ -1173,7 +1171,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -1187,7 +1185,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % ABlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" @@ -1202,7 +1200,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -1216,7 +1214,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.K % BBlockTransferSrcScalarPerVector != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K (" << karg.K << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" @@ -1231,7 +1229,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg N (" << karg.N << ") value is not a multiple of " @@ -1247,7 +1245,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 { if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) { - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg M (" << karg.M << ") value is not a multiple of " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index 94306a4c95..bac8c32886 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -38,8 +38,7 @@ __global__ void const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -52,7 +51,7 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +#endif // end of if (defined(__gfx9__)) } template +struct TransformConvBwdWeightToGemmV2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumBatchToMerge, C), + make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[4]; + const auto BatchStride = weights_strides[0]; + // Add NumBatchToMerge for Batch+M dimension and, 1 as a placehorder + // for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumBatchToMerge, K, Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumBatchToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || + NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || + NumBatchToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), + make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[5]; + const auto BatchStride = weights_strides[0]; + // Add NumBatchToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumBatchToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pad_transform(1, 0, NumBatchToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 || + NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 || + NumBatchToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)), + make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K * NumBatchToMerge; + const index_t GemmN = C * X * Y * NumBatchToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K * NumBatchToMerge; + const index_t GemmN = C * Z * X * Y * NumBatchToMerge; + + const auto PadGemmM = MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumBatchToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } // function end +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 0b6504e528..6455402dcb 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -124,7 +124,7 @@ struct EnvVar #define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "") -#define ENV(name) \ +#define CK_ENV(name) \ ck::env::name {} template diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index a123623f5b..9c6e85f013 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz return __builtin_bit_cast(int32x4_t, res); } +namespace impl { +// below type indicate the data type used for buffer load inline asm +// clang-format off +template struct buffer_load_trait; + +template struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; }; +template struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; }; +template struct buffer_load_trait<4 , T> { using payload_t = float; }; +template struct buffer_load_trait<2 , T> { using payload_t = float; }; +template struct buffer_load_trait<1 , T> { using payload_t = float; }; + +#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA +template<> struct buffer_load_trait<16, thread_buffer> { using payload_t = bf16x8_t; }; +template<> struct buffer_load_trait<8 , thread_buffer> { using payload_t = bf16x4_t; }; +template<> struct buffer_load_trait<4 , thread_buffer> { using payload_t = bf16x2_t; }; +#endif +// clang-format on +} // namespace impl + // TODO: glc/slc/... template struct buffer_load; @@ -48,7 +67,7 @@ struct buffer_load<16> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -68,7 +87,7 @@ struct buffer_load<8> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -88,7 +107,7 @@ struct buffer_load<4> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -108,7 +127,7 @@ struct buffer_load<2> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -128,7 +147,7 @@ struct buffer_load<1> index_t /*flag*/ = 0) { static_assert(sizeof(T) == 4); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) @@ -152,7 +171,7 @@ struct buffer_load_if<16> { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" @@ -177,7 +196,7 @@ struct buffer_load_if<8> { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x2_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" @@ -201,7 +220,7 @@ struct buffer_load_if<4> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" @@ -225,7 +244,7 @@ struct buffer_load_if<2> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" @@ -249,7 +268,7 @@ struct buffer_load_if<1> { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; asm volatile( "v_cmpx_le_u32 exec, 1, %5\n" "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 82b6953b5e..10045d8f7d 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -3,6 +3,21 @@ #pragma once +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) +#define __gfx9__ +#endif +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) +#define __gfx103__ +#endif +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#define __gfx11__ +#endif + #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -109,15 +124,13 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 @@ -137,13 +150,12 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff -#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || \ + defined(__gfx9__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 -#elif defined(__gfx1030__) // for GPU code +#elif defined(__gfx103__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#elif defined(__gfx11__) // for GPU code #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -159,3 +171,7 @@ #ifndef CK_TILE_FMHA_FWD_FAST_EXP2 #define CK_TILE_FMHA_FWD_FAST_EXP2 0 #endif + +#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA +#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 +#endif diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index bad1009f2c..56ca44e720 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -55,7 +55,7 @@ struct alignas(1) float8_e4m3_t { static constexpr int exponent = 4; static constexpr int mantissa = 3; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 1 << (exponent - 1); // NANOO #else static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE @@ -113,7 +113,7 @@ struct alignas(1) float8_e5m2_t { static constexpr int exponent = 5; static constexpr int mantissa = 2; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 1 << (exponent - 1); // NANOO #else static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE @@ -470,7 +470,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x) { constexpr int seed = 42; uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -500,7 +500,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) { constexpr int seed = 42; uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -526,7 +526,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float max_fp8 = 240.0f; x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union @@ -554,7 +554,7 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) } CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) union { float fval; @@ -598,7 +598,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); @@ -612,7 +612,7 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) float fval; uint32_t i32val = static_cast(x); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); @@ -656,7 +656,7 @@ struct numeric_traits { static constexpr int exp = 4; static constexpr int mant = 3; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 8; #else static constexpr int bias = 7; @@ -668,7 +668,7 @@ struct numeric_traits { static constexpr int exp = 5; static constexpr int mant = 2; -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) static constexpr int bias = 16; #else static constexpr int bias = 15; // IEEE diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index c616b6939f..752145f711 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -129,8 +129,8 @@ constexpr double fp16_to_double_hip(const fp16_hip_t& x) CK_TILE_HOST_DEVICE constexpr fp16_hip_t float_to_fp16_hip(const float& x) { - return __float2half(x); - // return static_cast(x); + // return __float2half(x); + return static_cast(x); } CK_TILE_HOST_DEVICE diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp index ea7a67abcc..33c24da8c5 100644 --- a/include/ck_tile/core/numeric/integral_constant.hpp +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+) CK_TILE_LEFT_UNARY_OP(-) CK_TILE_LEFT_UNARY_OP(~) CK_TILE_LEFT_UNARY_OP(!) -CK_TILE_LEFT_UNARY_OP(*) CK_TILE_BINARY_OP(+) CK_TILE_BINARY_OP(-) diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 90ad94b12b..48762b7225 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -112,7 +112,7 @@ namespace impl { template CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) // This API is designed to use the _pk_ serious of function constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 07a51ff9b6..09030fa6df 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -21,3 +21,4 @@ #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/device_memory.hpp b/include/ck_tile/host/device_memory.hpp index 91463a06a9..7c8549f74f 100644 --- a/include/ck_tile/host/device_memory.hpp +++ b/include/ck_tile/host/device_memory.hpp @@ -27,7 +27,14 @@ struct DeviceMem DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void Realloc(std::size_t mem_size) { @@ -36,7 +43,14 @@ struct DeviceMem HIP_CHECK_ERROR(hipFree(mpDeviceBuf)); } mMemSize = mem_size; - HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + if(mMemSize != 0) + { + HIP_CHECK_ERROR(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + } + else + { + mpDeviceBuf = nullptr; + } } void* GetDeviceBuffer() const { return mpDeviceBuf; } std::size_t GetBufferSize() const { return mMemSize; } @@ -47,15 +61,18 @@ struct DeviceMem HIP_CHECK_ERROR( hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } - else - { - throw std::runtime_error("ToDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("ToDevice with an empty pointer"); + // } } void ToDevice(const void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR( - hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR( + hipMemcpy(mpDeviceBuf, const_cast(p), cpySize, hipMemcpyHostToDevice)); + } } void FromDevice(void* p) const { @@ -63,14 +80,17 @@ struct DeviceMem { HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } - else - { - throw std::runtime_error("FromDevice with an empty pointer"); - } + // else + // { + // throw std::runtime_error("FromDevice with an empty pointer"); + // } } void FromDevice(void* p, const std::size_t cpySize) const { - HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + if(mpDeviceBuf) + { + HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost)); + } } void SetZero() const { @@ -82,13 +102,16 @@ struct DeviceMem template void SetValue(T x) const { - if(mMemSize % sizeof(T) != 0) + if(mpDeviceBuf) { - throw std::runtime_error("wrong! not entire DeviceMem will be set"); - } + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } - // TODO: call a gpu kernel to set the value (?) - set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + // TODO: call a gpu kernel to set the value (?) + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } } ~DeviceMem() { diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index 7053888abd..e9c5a0c254 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/timer.hpp" #include #include @@ -14,153 +15,92 @@ template -CK_TILE_HOST float launch_and_time_kernel(const stream_config& s, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) -{ -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { - // warm up - for(int i = 0; i < s.cold_niters_; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - const int nrepeat = s.nrepeat_; - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); - - for(int i = 0; i < nrepeat; ++i) - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); - - float total_time = 0; - - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); - - return total_time / nrepeat; - } - else - { - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; - } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - return 0; -#endif -} - -template -CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s, - PreProcessFunc preprocess, - F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - Args... args) -{ -#if CK_TILE_TIME_KERNEL - if(s.time_kernel_) - { -#if CK_TILE_DEBUG_LOG - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up 1 time\n"); -#endif - // warm up - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - const int nrepeat = 10; -#if CK_TILE_DEBUG_LOG - printf("Start running %d times...\n", nrepeat); -#endif - hipEvent_t start, stop; - - HIP_CHECK_ERROR(hipEventCreate(&start)); - HIP_CHECK_ERROR(hipEventCreate(&stop)); - - HIP_CHECK_ERROR(hipDeviceSynchronize()); - HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_)); - - for(int i = 0; i < nrepeat; ++i) - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - } - - HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_)); - HIP_CHECK_ERROR(hipEventSynchronize(stop)); - - float total_time = 0; - - HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop)); - - return total_time / nrepeat; - } - else - { - preprocess(); - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - return 0; - } -#else - kernel<<>>(args...); - hip_check_error(hipGetLastError()); - - return 0; -#endif + Kernel{}(args...); } +// +// return a anonymous functor(lambda) to be called later +// the KernelImpl should be a class without non-static data member, or let's say +// can be instantiate with "KernelImpl{}" +// +// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl +// template -CK_TILE_HOST float launch_kernel(const stream_config& s, - KernelImpl kernel_impl, - dim3 grid_dim, - dim3 block_dim, - std::size_t dynamic_smem_byte, - Args... args) +CK_TILE_HOST auto +make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { const auto kernel = kentry; - return launch_and_time_kernel( - s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...); + return [=](const stream_config& s) { + kernel<<>>(args...); + }; } + +// clang-format off +/* + * launch_kernel() + * + * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config) + * the callables should have signature as "operator()(const stream_config& s){ ... }" to call + * + * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }" + * as signature, for the callable (pay attention to the capture list) + * + * e.g. + * ck_tile::launch_kernel(s, + * [=](const stream_config& s){ hipMemset(ptr, 0, size) }, + * [=](const stream_config& s){ some_kernel<<>>(arg); } + * ); + * + * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}") + * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you, + * then pass it to ck_tile::launch_kernel() + * + * e.g. + * ck_tile::launch_kernel(s, + * ck_tile::make_kernel(kernel_0{}, grids0, blocks0, 0, kargs0), + * ck_tile::make_kernel(kernel_1{}, grids1, blocks1, 0, kargs1), + * ...); + **/ +// clang-format on +template +CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) +{ + // clang-format off + if(!s.time_kernel_) { + (callables(s),...); hip_check_error(hipGetLastError()); + return 0; + } + if(s.is_gpu_timer_) { + gpu_timer timer {}; + + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); + + return timer.duration() / s.nrepeat_; + } + else { + cpu_timer timer {}; + + // warmup + for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); + timer.stop(s.stream_id_); + + return timer.duration() / s.nrepeat_; + } + // clang-format on +} + } // namespace ck_tile diff --git a/include/ck_tile/host/stream_config.hpp b/include/ck_tile/host/stream_config.hpp index d29c6f0fa1..47cf0fd5e4 100644 --- a/include/ck_tile/host/stream_config.hpp +++ b/include/ck_tile/host/stream_config.hpp @@ -6,6 +6,22 @@ #include namespace ck_tile { +/* + * construct this structure with behavior as: + * + * // create stream config with default stream(NULL), and not timing the kernel + * stream_config s = stream_config{}; + * + * // create stream config with _some_stream_id_, and not timing the kernel + * stream_config s = stream_config{_some_stream_id_}; + * + * // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default + * stream_config s = stream_config{_some_stream_id_, true}; + * + * // create stream config with _some_stream_id_, and benchmark using cpu timer + * stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false}; + **/ + struct stream_config { hipStream_t stream_id_ = nullptr; @@ -13,5 +29,6 @@ struct stream_config int log_level_ = 0; int cold_niters_ = 3; int nrepeat_ = 10; + bool is_gpu_timer_ = true; // keep compatible }; } // namespace ck_tile diff --git a/include/ck_tile/host/timer.hpp b/include/ck_tile/host/timer.hpp new file mode 100644 index 0000000000..e2baeaef7c --- /dev/null +++ b/include/ck_tile/host/timer.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include +#include + +namespace ck_tile { + +struct gpu_timer +{ + CK_TILE_HOST gpu_timer() + { + HIP_CHECK_ERROR(hipEventCreate(&start_evt)); + HIP_CHECK_ERROR(hipEventCreate(&stop_evt)); + } + + CK_TILE_HOST ~gpu_timer() noexcept(false) + { + HIP_CHECK_ERROR(hipEventDestroy(start_evt)); + HIP_CHECK_ERROR(hipEventDestroy(stop_evt)); + } + + CK_TILE_HOST void start(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + HIP_CHECK_ERROR(hipEventRecord(start_evt, s)); + } + + CK_TILE_HOST void stop(const hipStream_t& s) + { + HIP_CHECK_ERROR(hipEventRecord(stop_evt, s)); + HIP_CHECK_ERROR(hipEventSynchronize(stop_evt)); + } + // return in ms + CK_TILE_HOST float duration() const + { + float ms = 0; + HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt)); + return ms; + } + + private: + hipEvent_t start_evt, stop_evt; +}; + +struct cpu_timer +{ + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void start(const hipStream_t&) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + start_tick = std::chrono::high_resolution_clock::now(); + } + // torch.utils.benchmark.Timer(), there is a sync inside each timer callback + CK_TILE_HOST void stop(const hipStream_t&) + { + HIP_CHECK_ERROR(hipDeviceSynchronize()); + stop_tick = std::chrono::high_resolution_clock::now(); + } + // return in ms + CK_TILE_HOST float duration() const + { + double sec = + std::chrono::duration_cast>(stop_tick - start_tick) + .count(); + return static_cast(sec * 1e3); + } + + private: + std::chrono::time_point start_tick; + std::chrono::time_point stop_tick; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp index 2360b7bebc..c2fdaf3a1a 100644 --- a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -23,13 +23,13 @@ VERTICAL: [0] 1 2 3 4 5 [0] 1 2 3 4 5 -TOP_LEFT: +TOP_LEFT(but negative): [0] 1 2 3 4 5 1 [0] 1 2 3 4 2 1 [0] 1 2 3 3 2 1 [0] 1 2 -FROM_BOTTOM_RIGHT: +FROM_BOTTOM_RIGHT(but negative): 2 1 [0] 1 2 3 3 2 1 [0] 1 2 4 3 2 1 [0] 1 diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 86a13bce89..6624bf1a9b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -79,7 +79,7 @@ struct FmhaFwdKernel return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s::name) + - "_" + (kIsGroupMode ? "group" : "batch") + "_" + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp index 214ecb7b7c..2dca84b786 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp @@ -18,6 +18,8 @@ struct FmhaFwdTilePartitioner static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + static constexpr const char* name = "shb"; + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, @@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner } }; +template +using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner; + +template +struct FmhaFwdTilePartitioner_HBS +{ + using BlockFmhaShape = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0; + static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0; + static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0; + static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1; + static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1; + + static constexpr const char* name = "hbs"; + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, kM0) * + ck_tile::integer_divide_ceil(hdim_v_, kN1)); + } + + CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index cb250516f4..dd164e72ea 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -36,8 +36,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -49,8 +48,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #else @@ -89,8 +87,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0); #else ck_tile::ignore = c_vec; @@ -102,8 +99,7 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) +#if defined(__gfx9__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #else @@ -143,7 +139,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -167,7 +163,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -220,7 +216,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0); #elif defined(__gfx908__) static_for<0, 2, 1>{}([&](auto k) { @@ -244,7 +240,7 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16 // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) return bit_cast( __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); #elif defined(__gfx908__) @@ -299,7 +295,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -333,7 +329,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 8120eff250..77d3728430 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -35,14 +35,24 @@ template + ConvolutionBackwardWeightSpecialization ConvSpec, + BlockGemmPipelineScheduler Scheduler, + BlockGemmPipelineVersion PipelineVersion> using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple< // clang-format off - //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| - //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| - //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| - //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | - DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 1>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 8, 1, 8>, 1> + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumBatch| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 32, 8, 32, 32, 1, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 32, 8, 32, 32, 1, 4, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8>, + + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 32, 32, 1, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 2>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 32, 8, 32, 32, 2, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 4>, + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 32, 8, 32, 32, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, S<4, 4, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 8> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index 91b7df3d45..5a703e5814 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -352,7 +352,9 @@ struct DeviceOperationInstanceFactory>>& instances); -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances( std::vector>>& instances); -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances( std::vector + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..6833ab20f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 98546de040..7cbf55e5f9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -52,112 +52,6 @@ void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( Interwave>{}); } -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - Multiply, - Add>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Multiply, - Add, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Multiply, - Add, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( - std::vector, - ck::Tuple<>, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple<>, - EDataType, - AElementOp, - Multiply, - PassThrough>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - PassThrough, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - PassThrough, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( - std::vector, - ck::Tuple<>, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple<>, - EDataType, - AElementOp, - Multiply, - FastGelu>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - FastGelu, - GemmMNKPadding, - Interwave>{}); - - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple<>, - ck::Tuple, - ck::Tuple<>, - Multiply, - FastGelu, - GemmMNKPadding, - Interwave>{}); -} - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..044fc26721 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); + + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..12bcf19252 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..e4a04d48b3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyAdd>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp index 5c46730ea7..590e89284f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i Interwave>{}); } -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - MultiplyAdd>>>& instances) -{ - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyAdd, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances(instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyAdd, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - Multiply>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - Multiply, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - Multiply, - GemmMNKPadding, - Interwave>{}); -} - -void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( - std::vector, - ck::Tuple, - ELayout, - AsDataType, - ck::Tuple, - ck::Tuple, - EDataType, - AElementOp, - PassThrough, - MultiplyFastGelu>>>& instances) -{ - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyFastGelu, - GemmMNKPadding, - Interwave>{}); - add_device_operation_instances( - instances, - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, - ck::Tuple, - ck::Tuple, - ck::Tuple, - PassThrough, - MultiplyFastGelu, - GemmMNKPadding, - Interwave>{}); -} - } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..5741ee29ac --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" + +#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index a21b7702b9..f730534c8d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -6,7 +6,9 @@ set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp - xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp) + xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp + xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp + ) if(DL_KERNELS) list(APPEND GROUPED_CONV2D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp similarity index 81% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp index ef583cf4fd..15401f0e1b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp @@ -10,7 +10,7 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances( +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances( std::vector{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< - 2, - NHWGC, - GKYXC, - NHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp new file mode 100644 index 0000000000..398c14b11c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 435d1831e5..8e939c15a9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,12 +1,14 @@ -# XDL_DL_WMMA_KERNELS + # XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp) + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp + ) if(DL_KERNELS) list(APPEND GROUPED_CONV3D_BWD_WEIGHT diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp similarity index 81% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp index c4849c017e..4d0f1e68cb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp @@ -10,7 +10,7 @@ namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( std::vector{}); - // 2. Filter1x1Stride1Pad0 - add_device_operation_instances( - instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - ConvBwdWeightFilter1x1Stride1Pad0>{}); + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v2>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp new file mode 100644 index 0000000000..c5cc062f2a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index 80c1c42b83..09e03de99c 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 476ec37eb2..0b73e4fcd1 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp index 33e758f406..74faf15be3 100644 --- a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp @@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification, Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp index feb0be87e7..14df96d505 100644 --- a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp @@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification, c_m_n_host_results.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); - if(ck::EnvIsEnabled(ENV(CK_LOGGING))) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..8e7e8607ba --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "rocm-composable-kernel" +dynamic = ["version"] +description = "Composable Kernel, performance-critical kernels for machine learning workloads" +readme = "README.md" +requires-python = ">=3.8" +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +dependencies = [] + +[project.urls] +"Homepage" = "https://github.com/rocm/composable_kernel" +"Bug Tracker" = "https://github.com/rocm/composable_kernel/issues" + +[tool.setuptools] +packages = ["ck4inductor", "ck4inductor.include", "ck4inductor.library"] + +[tool.setuptools.package-dir] +ck4inductor = "python/ck4inductor" +"ck4inductor.include" = "include" +"ck4inductor.library" = "library" + +[tool.setuptools.package-data] +"ck4inductor.include" = ["ck/**/*.hpp"] +"ck4inductor.library" = ["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"] + +[tool.setuptools.dynamic] +version = { attr = "setuptools_scm.get_version" } diff --git a/python/ck4inductor/__init__.py b/python/ck4inductor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/ck4inductor/universal_gemm/gen_instances.py b/python/ck4inductor/universal_gemm/gen_instances.py new file mode 100644 index 0000000000..8b6d6b73b2 --- /dev/null +++ b/python/ck4inductor/universal_gemm/gen_instances.py @@ -0,0 +1,570 @@ +import logging +import os +import subprocess +from dataclasses import fields, replace +from functools import lru_cache, partial +from typing import List + +from ..util import library_path + +from .op import CKGemmOperation + +log = logging.getLogger(__name__) + + +def _ck_library_dir(): + gemm_instances_path = os.path.join( + library_path(), "src", "tensor_operation_instance", "gpu", "gemm_universal" + ) + if not os.path.exists(gemm_instances_path): + log.error("CK library path %s does not exist", gemm_instances_path) + return None + return gemm_instances_path + + +def parse_instances(str_instances: List[str]) -> List[CKGemmOperation]: + """ + Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances + """ + + def maybe_int(s): + try: + return int(s) + except ValueError: + return s + + op_instances = [] + for line in str_instances: + s_template_args = line.split("DeviceGemm_Xdl_CShuffleV3")[-1].strip("<>, ") + template_args = [] + i_current = 0 + while i_current < len(s_template_args): + if s_template_args[i_current] == " ": + # skip whitespace + i_current += 1 + continue + elif s_template_args[i_current : i_current + 2] == "S<": + # parse template S + i_next = s_template_args.find(">", i_current) + template_args.append( + tuple(map(int, s_template_args[i_current + 2 : i_next].split(","))) + ) + i_current = i_next + 2 + else: + # all string attributes must be either type aliases or global constants in C++ + i_next = s_template_args.find(",", i_current) + template_args.append( + maybe_int( + s_template_args[i_current : i_next if i_next != -1 else None] + ) + ) + if i_next != -1: + i_current = i_next + 1 + if i_next == -1: + break + # pad with `None`s for the fields which are not defined in the instance + new_instance = CKGemmOperation( + *template_args, # type: ignore[arg-type] + *((None,) * (len(fields(CKGemmOperation)) - len(template_args))), + ) + # the last 2 template parameters are optional + # if they are absent, substitute them with default values from Universal Gemm C++ template declaration + if new_instance.a_compute_dtype is None: + new_instance.a_compute_dtype = new_instance.c_element_dtype + if new_instance.b_compute_dtype is None: + new_instance.b_compute_dtype = new_instance.c_element_dtype + + op_instances.append(new_instance) + return op_instances + + +def default_instances() -> List[CKGemmOperation]: + # fallback: known working op instance for problem size M=2240 K=256 N=2048 + # all string attributes must be either type aliases or global constants in C++ + + return [ + CKGemmOperation( + a_layout="Row", + b_layout="Row", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + a_compute_dtype="F16", + b_compute_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + gemm_specialization="GemmSpecialization::Default", + block_size=256, + m_per_block=224, + n_per_block=256, + k_per_block=64, + a_k1=8, + b_k1=2, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, # type: ignore[arg-type] + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + b_block_transfer_thread_cluster_arrange_order=(0, 2, 1), + b_block_transfer_src_access_order=(0, 2, 1), + b_block_transfer_src_vector_dim=1, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=2, + b_block_lds_extra_n=0, # type: ignore[arg-type] + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ) + ] + + +@lru_cache(None) +def gen_ops_library() -> List[CKGemmOperation]: + """ + Parse the Universal Gemm instances defined in the composable kernel library folder. + """ + ck_library_dir = _ck_library_dir() + if not ck_library_dir: + return [] + + grep_result = subprocess.run( + [ + "grep", + "-inR", + "DeviceGemm_Xdl_CShuffleV3", + _ck_library_dir(), + ], + capture_output=True, + text=True, + ) + + op_instances = parse_instances(grep_result.stdout.strip().split("\n")) + + log.debug("ck instances from library: %d", len(op_instances)) + + schedulers = [ + "BlockGemmPipelineScheduler::Intrawave", + "BlockGemmPipelineScheduler::Interwave", + ] + gemm_specs = [ + "GemmSpecialization::Default", + "GemmSpecialization::MPadding", + "GemmSpecialization::NPadding", + "GemmSpecialization::KPadding", + "GemmSpecialization::MNPadding", + "GemmSpecialization::MKPadding", + "GemmSpecialization::NKPadding", + "GemmSpecialization::MNKPadding", + ] + + # substitute templated args by looping through their domains + substitute_instances = [] + for instance in op_instances: + sub_scheduler = instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched" + sub_spec = instance.gemm_specialization == "GemmSpec" + schedulers_range = ( + schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler] + ) + spec_range = gemm_specs if sub_spec else [instance.gemm_specialization] + for scheduler in schedulers_range: + for spec in spec_range: + substitute_instances.append( + replace( + instance, + block_gemm_pipeline_scheduler=scheduler, + gemm_specialization=spec, + ) + ) + + return substitute_instances + + +@lru_cache(None) +def gen_ops_preselected() -> List[CKGemmOperation]: + """ + Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances + """ + ck_gemm_f16_rcr = partial( + CKGemmOperation, + a_layout="Row", + b_layout="Col", + c_layout="Row", + a_element_dtype="F16", + b_element_dtype="F16", + c_element_dtype="F16", + acc_dtype="F32", + c_shuffle_dtype="F16", + a_elementwise_op="PassThrough", + b_elementwise_op="PassThrough", + c_elementwise_op="PassThrough", + k_per_block=64, + a_k1=8, + b_k1=8, + a_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + a_block_transfer_src_access_order=(1, 0, 2), + a_block_transfer_src_vector_dim=2, + a_block_transfer_src_scalar_per_vector=8, + a_block_transfer_dst_scalar_per_vector_ak1=8, + a_block_lds_extra_m=0, + b_block_transfer_thread_cluster_arrange_order=(1, 0, 2), + b_block_transfer_src_access_order=(1, 0, 2), + b_block_transfer_src_vector_dim=2, + b_block_transfer_src_scalar_per_vector=8, + b_block_transfer_dst_scalar_per_vector_bk1=8, + b_block_lds_extra_n=0, + a_compute_dtype="F16", + b_compute_dtype="F16", + ) + ck_gemm_f16_rcr_compute_friendly = partial( + ck_gemm_f16_rcr, + block_size=256, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1), + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ) + ck_gemm_f16_rcr_memory_friendly = partial( + ck_gemm_f16_rcr, + block_size=128, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v2", + ) + ck_gemm_f16_rcr_latency_friendly = partial( + ck_gemm_f16_rcr, + gemm_specialization="GemmSpecialization::Default", + block_size=128, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1), + b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1), + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v1", + ) + return [ + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=224, + n_per_block=256, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=7, + n_xdl_per_wave=8, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v3", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v4", + ), + ck_gemm_f16_rcr_compute_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=128, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave", + block_gemm_pipeline_version="BlockGemmPipelineVersion::v5", + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=32, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=16, + n_per_block=64, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=2, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=64, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=128, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=2, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::Default", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=32, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=4, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=16, + m_per_xdl=16, + n_per_xdl=16, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 64, + 1, + 2, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=64, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=1, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=1, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_memory_friendly( + gemm_specialization="GemmSpecialization::MNKPadding", + m_per_block=128, + n_per_block=32, + m_per_xdl=32, + n_per_xdl=32, + m_xdl_per_wave=2, + n_xdl_per_wave=1, + c_shuffle_m_xdl_per_wave_per_shuffle=2, + c_shuffle_n_xdl_per_wave_per_shuffle=1, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + c_shuffle_block_transfer_scalar_per_vector_n_per_block=8, + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=16, + n_per_block=32, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 16, + 1, + 8, + ), + ), + ck_gemm_f16_rcr_latency_friendly( + m_per_block=32, + n_per_block=16, + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=( + 1, + 32, + 1, + 4, + ), + ), + ] + + +if __name__ == "__main__": + print(gen_ops_library()) diff --git a/python/ck4inductor/universal_gemm/op.py b/python/ck4inductor/universal_gemm/op.py new file mode 100644 index 0000000000..ab541c5fb9 --- /dev/null +++ b/python/ck4inductor/universal_gemm/op.py @@ -0,0 +1,95 @@ +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + + +@dataclass +class CKGemmOperation: + """ + A python dataclass storing the template parameters of a CK Universal Gemm template instance + """ + + a_layout: str + b_layout: str + c_layout: str + + a_element_dtype: str + b_element_dtype: str + c_element_dtype: str + + acc_dtype: str + c_shuffle_dtype: str + + a_elementwise_op: str + b_elementwise_op: str + c_elementwise_op: str + + gemm_specialization: str + + block_size: int + + m_per_block: int + n_per_block: int + k_per_block: int + + a_k1: int + b_k1: int + + m_per_xdl: int + n_per_xdl: int + + m_xdl_per_wave: int + n_xdl_per_wave: int + + a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int] + a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + a_block_transfer_src_access_order: Tuple[int, int, int] + a_block_transfer_src_vector_dim: int + a_block_transfer_src_scalar_per_vector: int + a_block_transfer_dst_scalar_per_vector_ak1: int + a_block_lds_extra_m: bool + + b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int] + b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int] + b_block_transfer_src_access_order: Tuple[int, int, int] + + b_block_transfer_src_vector_dim: int + b_block_transfer_src_scalar_per_vector: int + b_block_transfer_dst_scalar_per_vector_bk1: int + b_block_lds_extra_n: bool + + c_shuffle_m_xdl_per_wave_per_shuffle: int + c_shuffle_n_xdl_per_wave_per_shuffle: int + + c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: ( + Tuple[int, int, int, int] + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block: int + + block_gemm_pipeline_scheduler: str + block_gemm_pipeline_version: Optional[str] + + a_compute_dtype: Optional[str] + b_compute_dtype: Optional[str] + + def name(self): + # cpp alias for template instance + return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}" + + def key_name(self): + # TBD; must be unique per instance. Intended to use as dict key + return "_".join( + [ + "K" + + field_name.replace("_", "").lower() + + "V" + + ( + "x".join(map(str, iter(field_value))) + if isinstance(field_value, tuple) + else str(field_value).replace(":", "") + ) + for field_name, field_value in self.dict_items() + ] + ) + + def dict_items(self): + return asdict(self).items() diff --git a/python/ck4inductor/util.py b/python/ck4inductor/util.py new file mode 100644 index 0000000000..79d6be00f3 --- /dev/null +++ b/python/ck4inductor/util.py @@ -0,0 +1,7 @@ +import functools +import os + + +@functools.lru_cache(None) +def library_path(): + return os.path.join(os.path.dirname(__file__), 'library') diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 25c63ac7fe..49b67992b1 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(TEST_TARGETS ${GPU_TARGETS}) + endif() + foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl test ${source} ") @@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") message("removing xdl test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) + if(ARGN MATCHES "_xdl") + list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) + set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS} ) target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt) add_test(NAME ${TEST_NAME} COMMAND $) add_dependencies(tests ${TEST_NAME}) @@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() endif() + + if(INSTANCES_ONLY) + set(TEST_TARGETS ${DEFAULT_GPU_TARGETS}) + else() + set(TEST_TARGETS ${GPU_TARGETS}) + endif() + foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl test ${source} ") @@ -112,20 +133,27 @@ function(add_gtest_executable TEST_NAME) endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") message("removing xdl test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() foreach(source IN LISTS ARGN) - if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") message("removing wmma test ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() #only continue if there are some source files left on the list if(ARGN) + if(ARGN MATCHES "_xdl") + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103) + elseif(ARGN MATCHES "_wmma") + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + endif() + set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) + set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS} ) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 1c8082645c..5ef0730668 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test std::vector conv_params; std::vector split_ks{1, 2}; - bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k) + bool skip_case(const ck::index_t split_k) { - // Odd K or C values are supported only by DL and WMMA - // kernels (only applies to fp16) - // DL and WMMA kernels currently support only `split_k=1` - if constexpr(std::is_same_v) - { - if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) - { - return true; - } - } - // 1d NWGC is only supported by DL kernel // DL kernel is only supported for split_k=1 if constexpr(std::is_same_v && std::is_same_v) @@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test { for(auto& param : conv_params) { - if(!skip_case(param, split_k)) + if(!skip_case(split_k)) { pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_implconv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}); this->Run(); } @@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->Run(); } diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index f47685cf91..55cb209772 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -6,6 +6,12 @@ if(result EQUAL 0) add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) endif() +add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk) +endif() + add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp new file mode 100644 index 0000000000..67ecbaea30 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +#include "gtest/gtest.h" +#include "test_grouped_gemm_util.hpp" + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage>; +using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage>; +using RRR_F16_F16_F16_LargeK = + ck::test::TestGroupedGemmTwoStage>; +using RCR_F16_F16_F16_LargeK = + ck::test::TestGroupedGemmTwoStage>; +using RRR_BF16_BF16_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RCR_BF16_BF16_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RRR_BF16_I8_BF16 = + ck::test::TestGroupedGemmTwoStage>; +using RCR_BF16_I8_BF16 = + ck::test::TestGroupedGemmTwoStage>; + +const std::vector KBATCH{1, 2, 3, 5, 8}; + +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN, + RRR_F16_F16_F16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK, + RCR_F16_F16_F16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16, + RRR_BF16_BF16_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16, + RCR_BF16_BF16_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8, + RRR_BF16_I8_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8, + RCR_BF16_I8_BF16, + testing::ValuesIn(KBATCH)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN, + RRR_F16_F16_F16_LargeK, + testing::Values(32, 64)); +INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK, + RCR_F16_F16_F16_LargeK, + testing::Values(32, 64)); + +#include "test_grouped_gemm_ut_cases.inc" +#include "test_grouped_gemm_two_stage_ut_cases.inc" diff --git a/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc new file mode 100644 index 0000000000..40d48f4ec0 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_two_stage_ut_cases.inc @@ -0,0 +1,61 @@ +#pragma once + +TEST_P(RRR_BF16_BF16_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_BF16_BF16_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RRR_BF16_I8_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), N); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} + +TEST_P(RCR_BF16_I8_BF16, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + const std::vector StrideAs(Ms.size(), K); + const std::vector StrideBs(Ms.size(), K); + const std::vector StrideCs(Ms.size(), N); + + this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam()); +} diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 50f423ada3..9e1395b9f8 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/number.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" +#include "profiler/profile_grouped_gemm_two_stage_impl.hpp" namespace ck { namespace test { @@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam } }; +template +class TestGroupedGemmTwoStage : public testing::TestWithParam +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using ELayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using EDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + + void SetUp() override {} + + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + template (4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, - 1, 0, 1, 2, 3, 4, - 2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0, - 4, 3, 2, 1, - 5, 4, 3, 2}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2, - 4, 3, 2, 1, 0, 1, - 5, 4, 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, - 1, 2, 3, 4, - 0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, - 1, 0, 1, 2, 3, 4, - 2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5, + -1, 0, -1, -2, -3, -4, + -2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0, - 4, 3, 2, 1, - 5, 4, 3, 2}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0, + -4, -3, -2, -1, + -5, -4, -3, -2}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, - 3, 2, 1, 0, 1, 2, - 4, 3, 2, 1, 0, 1, - 5, 4, 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1, 0, -1, -2, -3, + -3, -2, -1, 0, -1, -2, + -4, -3, -2, -1, 0, -1, + -5, -4, -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, - 1, 2, 3, 4, - 0, 1, 2, 3, - 1, 0, 1, 2, - 2, 1, 0, 1, - 3, 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5, + -1, -2, -3, -4, + 0, -1, -2, -3, + -1, 0, -1, -2, + -2, -1, 0, -1, + -3, -2, -1, 0}); - rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, - 1, 0, 1, - 2, 1, 0}); + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2, + -1, 0, -1, + -2, -1, 0}); rtn &= test_alibi_slope_generation(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625}); rtn &= test_alibi_slope_generation(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692,