diff --git a/Jenkinsfile b/Jenkinsfile index 132257ad80..48b4c805cd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -735,11 +735,11 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true + 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true;RUN_CODEGEN_TESTS=true 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false - 0 13 * * * % BUILD_LEGACY_OS=true ''' : "" + 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false + 0 13 * * * % BUILD_LEGACY_OS=true''' : "" pipeline { agent none @@ -806,6 +806,10 @@ pipeline { name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", defaultValue: false, description: "Run the grouped conv large cases tests (default: OFF)") + booleanParam( + name: "RUN_CODEGEN_TESTS", + defaultValue: false, + description: "Run codegen tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, @@ -926,7 +930,30 @@ pipeline { execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ make -j64 test_grouped_convnd_fwd_large_cases_xdl && \ ./bin/test_grouped_convnd_fwd_large_cases_xdl""" - } + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run Codegen Tests") + { + parallel + { + stage("Run Codegen Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CODEGEN_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + make -j64 check""" + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -951,7 +978,7 @@ pipeline { make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -970,7 +997,7 @@ pipeline { make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -995,7 +1022,7 @@ pipeline { make -j64 tile_example_gemm_basic && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -1014,7 +1041,7 @@ pipeline { make -j64 tile_example_gemm_basic && \ cd ../ && example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -1040,7 +1067,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " \ -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ execute_args = " " - } + } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) cleanWs() @@ -1059,7 +1086,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " \ -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ execute_args = " " - } + } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) cleanWs() @@ -1140,7 +1167,7 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ - } + } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 2492804f28..1ca0d12821 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -1,3 +1,6 @@ +cmake_minimum_required(VERSION 3.16) +project(composable_kernel_host) + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) @@ -5,56 +8,51 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) -add_compile_options(-std=c++17) -find_package(hip) -add_custom_target(codegen) +find_package(ROCM) +include(ROCMInstallTargets) +include(ROCMTest) -# add include directories -include_directories(BEFORE - ${PROJECT_BINARY_DIR}/include - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/library/include - ${HIP_INCLUDE_DIRS} - ) +rocm_setup_version(VERSION 1.0) list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS - ${CK_ROOT}/include/ck/*.hpp) -#printouts fot debug purposes -#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") -#message(STATUS "RELATIVE: ${CK_ROOT}/include") + ${CK_ROOT}/include/ck/*.hpp) +# printouts fot debug purposes +# message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") +# message(STATUS "RELATIVE: ${CK_ROOT}/include") add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) +add_compile_options(-std=c++17) -##message(STATUS "SOURCE_FILES: ${SOURCES}") +file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) target_link_libraries(ck_host PRIVATE ck_headers) -set_target_properties(ck_host PROPERTIES - LINKER_LANGUAGE CXX - POSITION_INDEPENDENT_CODE ON) +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) -target_include_directories(ck_host PUBLIC - $ - $ -) +# target_include_directories(ck_host PUBLIC +# $ +# ) add_executable(ck-template-driver driver/main.cpp) target_link_libraries(ck-template-driver ck_host) -rocm_install( +rocm_install_targets( TARGETS ck_host ck_headers - EXPORT ck_hostTargets + EXPORT ck_host_targets + INCLUDE include + PRIVATE +) +rocm_export_targets( + EXPORT ck_host_targets + NAMESPACE composable_kernel:: ) -rocm_install(EXPORT ck_hostTargets - FILE composable_kernelck_hostTargets.cmake - NAMESPACE composable_kernel:: - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel) -rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) if(BUILD_TESTING) - add_subdirectory(test) + add_subdirectory(test) endif() + diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 1de612e49a..48fde531da 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,23 +1,25 @@ list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) -# do not build the tests when we build the library for various targets -if(NOT GPU_ARCHS) - foreach(TEST_SRC ${TEST_SRCS}) - set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) - get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) - add_executable(codegen_test_${BASE_NAME} ${TEST_SRC}) - if(CK_USE_ALTERNATIVE_PYTHON) - target_link_options(codegen_test_${BASE_NAME} PRIVATE -lstdc++fs) - endif() - add_dependencies(codegen codegen_test_${BASE_NAME}) - add_dependencies(tests codegen_test_${BASE_NAME}) - add_dependencies(check codegen_test_${BASE_NAME}) - add_test(NAME codegen_test_${BASE_NAME} COMMAND codegen_test_${BASE_NAME}) - message("adding test codegen_test_${BASE_NAME}") - target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host) - target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/codegen/test/include) + +# TODO: These tests need to be refactored to remove dependency on main ck +# headers and device compilation. +set(TESTS_REQUIRE_DEVICE_COMPILE + grouped_conv_fwd_multiple_d_v1 + grouped_conv_fwd_multiple_d_v2 + grouped_conv_fwd_multiple_d_v3 + grouped_conv_fwd_multiple_d_v4 +) +find_package(hip) + +foreach(TEST_SRC ${TEST_SRCS}) + get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) + rocm_add_test_executable(codegen_test_${BASE_NAME} ${TEST_SRC}) + target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host) + target_include_directories(codegen_test_${BASE_NAME} PUBLIC include) + if(BASE_NAME IN_LIST TESTS_REQUIRE_DEVICE_COMPILE) + target_link_libraries(codegen_test_${BASE_NAME} hip::device) target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/include) target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include) - endforeach() -endif() + endif() +endforeach() diff --git a/codegen/test/common.hpp b/codegen/test/include/common.hpp similarity index 90% rename from codegen/test/common.hpp rename to codegen/test/include/common.hpp index 7ea0b8cc83..6873b3b436 100644 --- a/codegen/test/common.hpp +++ b/codegen/test/include/common.hpp @@ -1,36 +1,26 @@ #pragma once #include "ck/host/headers.hpp" -#include "ck/host/stringutils.hpp" #include #include #include #include #include -#include #include #include #include #include -// NOLINTNEXTLINE -const char* const ck_content_wrapper = R"__ck__( -${content} -)__ck__"; - -template -inline std::string content_wrapper(P p) -{ - return ck::host::InterpolateString(ck_content_wrapper, - {{"content", std::string{p.data(), p.size()}}}); -} - inline std::vector create_headers_for_test() { auto ck_headers = ck::host::GetHeaders(); std::vector result; std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) { - return rtc::src_file{p.first, content_wrapper(p.second)}; + std::string content; + content.reserve(p.second.size() + 1); + content.push_back(' '); // We need a whitespace before the content for hipRTC to work + content.append(p.second.data(), p.second.size()); + return rtc::src_file{p.first, std::move(content)}; }); return result; } diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index a83574947d..2e7ceb5648 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -1,7 +1,9 @@ +find_package(hip) file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) target_link_libraries(ck_rtc PUBLIC hip::host) +target_link_libraries(ck_rtc PUBLIC -lstdc++fs) option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) if(USE_HIPRTC_FOR_CODEGEN_TESTS) diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index 0b5decc311..b8d29bd29a 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -1,9 +1,8 @@ #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL -#include #include -#include +#include #include namespace rtc { @@ -11,7 +10,7 @@ namespace rtc { struct src_file { src_file(std::filesystem::path p, std::string c) : path{std::move(p)}, content{std::move(c)} {} - CK::fs::path path; + fs::path path; std::string content; }; diff --git a/codegen/test/rtc/include/rtc/filesystem.hpp b/codegen/test/rtc/include/rtc/filesystem.hpp new file mode 100644 index 0000000000..3b94b84b9f --- /dev/null +++ b/codegen/test/rtc/include/rtc/filesystem.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef GUARD_TEST_HOST_RTC_FILESYSTEM_HPP +#define GUARD_TEST_HOST_RTC_FILESYSTEM_HPP + +#include +#include + +// clang-format off +#if defined(CPPCHECK) + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 1 +#elif defined(_WIN32) + #if _MSC_VER >= 1920 + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 0 + #elif _MSC_VER >= 1900 + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#elif defined(__has_include) + #if __has_include() && __cplusplus >= 201703L + #define RTC_HAS_FILESYSTEM 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #endif + #if __has_include() && __cplusplus >= 201103L + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 +#endif +// clang-format on + +#if RTC_HAS_FILESYSTEM +#include +#elif RTC_HAS_FILESYSTEM_TS +#include +#else +#error "No filesystem include available" +#endif + +namespace rtc { + +#if RTC_HAS_FILESYSTEM +namespace fs = ::std::filesystem; +#elif RTC_HAS_FILESYSTEM_TS +namespace fs = ::std::experimental::filesystem; +#endif + +} // namespace rtc + +#endif // GUARD_RTC_FILESYSTEM_HPP_ diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index 0b4bf002c1..a0a2cb9b77 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -2,13 +2,13 @@ #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR #include -#include +#include namespace rtc { struct tmp_dir { - CK::fs::path path; + fs::path path; tmp_dir(const std::string& prefix = ""); void execute(const std::string& cmd) const; diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index a3377d6853..c35c11b670 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -1,10 +1,11 @@ -#include #include +#include #ifdef HIPRTC_FOR_CODEGEN_TESTS #include #include #endif #include +#include #include #include #include @@ -96,9 +97,9 @@ kernel clang_compile_kernel(const std::vector& srcs, compile_options o for(const auto& src : srcs) { - CK::fs::path full_path = td.path / src.path; - CK::fs::path parent_path = full_path.parent_path(); - CK::fs::create_directories(parent_path); + fs::path full_path = td.path / src.path; + fs::path parent_path = full_path.parent_path(); + fs::create_directories(parent_path); write_string(full_path.string(), src.content); if(src.path.extension().string() == ".cpp") { @@ -112,7 +113,7 @@ kernel clang_compile_kernel(const std::vector& srcs, compile_options o td.execute(compiler() + options.flags); auto out_path = td.path / out; - if(not CK::fs::exists(out_path)) + if(not fs::exists(out_path)) throw std::runtime_error("Output file missing: " + out); auto obj = read_buffer(out_path.string()); @@ -204,7 +205,7 @@ struct hiprtc_program } else { - headers.push_back(std::string(src.content.begin(), src.content.end())); + headers.push_back(std::move(src.content)); include_names.push_back(std::move(src.path)); } } diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp index 659bbbe13f..4e89bc3539 100644 --- a/codegen/test/rtc/src/tmp_dir.cpp +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -31,10 +31,10 @@ std::string unique_string(const std::string& prefix) } tmp_dir::tmp_dir(const std::string& prefix) - : path(CK::fs::temp_directory_path() / + : path(fs::temp_directory_path() / unique_string(prefix.empty() ? "ck-rtc" : "ck-rtc-" + prefix)) { - CK::fs::create_directories(this->path); + fs::create_directories(this->path); } void tmp_dir::execute(const std::string& cmd) const @@ -43,6 +43,6 @@ void tmp_dir::execute(const std::string& cmd) const std::system(s.c_str()); } -tmp_dir::~tmp_dir() { CK::fs::remove_all(this->path); } +tmp_dir::~tmp_dir() { fs::remove_all(this->path); } } // namespace rtc diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index e3c8d72590..569afed256 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,4 +1,3 @@ - // SPDX-License-Identifier: MIT // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. @@ -282,7 +281,11 @@ int main(int argc, char* argv[]) using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + using CodegenGemmPolicy = ck_tile:: + UniversalGemmPipelineAgBgCrPolicy; + + using CodegenGemmPipeline = + ck_tile::GemmPipelineAGmemBGmemCRegV1; invoke_gemm) - { - if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of B - if constexpr(is_same_v) - { - if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of B1 - if constexpr(is_same_v) - { - if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of C - if constexpr(is_same_v) - { - if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - return false; - } - } - else if constexpr(is_same_v) - { - if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - return false; - } - } - else - { - return false; - } - - return true; - } - #ifndef __HIPCC_RTC__ static bool IsSupportedArgument(const Argument& arg) { @@ -861,268 +771,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } #endif - - template - struct Descriptor - { - template - static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc) - { - const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc); - - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / AK1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - template - static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc) - { - const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc); - - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / BK1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - template - static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc) - { - const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc); - - const auto N = b1_grid_desc_n_k.GetLength(I0); - const auto K = b1_grid_desc_n_k.GetLength(I1); - - const auto B1K0 = K / B1K1; - - return transform_tensor_descriptor( - b1_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - template - static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc) - { - return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc); - } - - using AGridDesc_AK0_M_AK1 = - remove_cvref_t; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using B1GridDesc_BK0_N_BK1 = - remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - - // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< - ADataType, // TODO: distinguish A/B datatype - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - AccElementwiseOperation, - B1ElementwiseOperation, - CElementwiseOperation, - InMemoryDataOperationEnum::Set, - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, - B1GridDesc_BK0_N_BK1, - CGridDesc_M_N, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - Gemm1NPerBlock, - Gemm1KPerBlock, - AK1, - BK1, - B1K1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - Gemm1NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - true, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - true, - BBlockLdsExtraN, - B1BlockTransferThreadClusterLengths_BK0_N_BK1, - B1BlockTransferThreadClusterArrangeOrder, - B1BlockTransferSrcAccessOrder, - B1BlockTransferSrcVectorDim, - B1BlockTransferSrcScalarPerVector, - B1BlockTransferDstScalarPerVector_BK1, - false, - B1BlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - LoopSched, - matrix_padder.PadN, - MaskOutUpperTriangle>; - - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; - B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; - CGridDesc_M_N c_grid_desc_m_n; - C0MatrixMask c0_matrix_mask; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_descriptor_mblock_mperblock_nblock_nperblock; - - // element-wise op - AElementwiseOperation a_element_op; - BElementwiseOperation b_element_op; - B1ElementwiseOperation b1_element_op; - CElementwiseOperation c_element_op; - - bool has_main_k_block_loop = true; - bool is_valid = false; - - constexpr Descriptor(ADesc a, - BDesc b, - B1Desc b1, - CDesc c, - AElementwiseOperation a_element_op_, - BElementwiseOperation b_element_op_, - B1ElementwiseOperation b1_element_op_, - CElementwiseOperation c_element_op_) - : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)}, - b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, - b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, - c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, - block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, - c_grid_descriptor_mblock_mperblock_nblock_nperblock{ - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n)}, - has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( - a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, - c0_matrix_mask{c.GetLength(I1)}, - a_element_op{a_element_op_}, - b_element_op{b_element_op_}, - b1_element_op{b1_element_op_}, - c_element_op{c_element_op_}, - is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - b1_grid_desc_bk0_n_bk1, - c_grid_desc_m_n, - block_2_ctile_map) and - IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), - b_grid_desc_bk0_n_bk1.GetLength(I1), - a_grid_desc_ak0_m_ak1.GetLength(I0) * - a_grid_desc_ak0_m_ak1.GetLength(I2), - b1_grid_desc_bk0_n_bk1.GetLength(I1))} - { - } - - constexpr bool IsValid() const { return is_valid; } - }; - - template - static constexpr auto - make_descriptor(ADesc a, - BDesc b, - B1Desc b1, - CDesc c, - AElementwiseOperation a_element_op = AElementwiseOperation{}, - BElementwiseOperation b_element_op = BElementwiseOperation{}, - B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{}, - CElementwiseOperation c_element_op = CElementwiseOperation{}) - { - return Descriptor( - a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op); - } - - template - __device__ static void Run(const Desc& desc, - const float scale, - const ADataType* __restrict__ p_a_grid, - const ADataType* __restrict__ p_b_grid, - const ADataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid) - { -#ifndef __HIPCC_RTC__ - assert(desc.is_valid); -#endif - __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; - AccElementwiseOperation acc_element_op{scale}; - - if(desc.has_main_k_block_loop) - { - Desc::GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_b1_grid, - p_c_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - acc_element_op, - desc.b1_element_op, - desc.c_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.b1_grid_desc_bk0_n_bk1, - desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, - desc.block_2_ctile_map, - desc.c0_matrix_mask); - } - else - { - Desc::GridwiseGemm::template Run( - p_a_grid, - p_b_grid, - p_b1_grid, - p_c_grid, - p_shared_block, - desc.a_element_op, - desc.b_element_op, - acc_element_op, - desc.b1_element_op, - desc.c_element_op, - desc.a_grid_desc_ak0_m_ak1, - desc.b_grid_desc_bk0_n_bk1, - desc.b1_grid_desc_bk0_n_bk1, - desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, - desc.block_2_ctile_map, - desc.c0_matrix_mask); - } - } }; } // namespace device diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 05f9ab31e0..2f4bbfb09f 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck/utility/enable_if.hpp" #include "ck/utility/statically_indexed_array.hpp" #ifdef __HIPCC_RTC__ @@ -204,7 +205,7 @@ struct scalar_type }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using type = d1_t; @@ -240,7 +241,7 @@ struct vector_type()>> __device__ int static err = 0; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -300,7 +301,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -370,7 +371,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -452,7 +453,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -546,7 +547,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -650,7 +651,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -766,7 +767,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -892,7 +893,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -1042,7 +1043,7 @@ struct non_native_vector_base // non-native vector_type implementation template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using type = d1_t; @@ -1077,7 +1078,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1137,7 +1138,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1207,7 +1208,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1289,7 +1290,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1383,7 +1384,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; @@ -1487,7 +1488,7 @@ struct vector_type()>> }; template -struct vector_type()>> +struct vector_type()>> { using d1_t = T; using d2_t = non_native_vector_base; diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index edef4a1257..fc5274f244 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/ck.hpp" +#include "data_type.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index a08c4f3811..a8bc27cdff 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -157,8 +157,11 @@ #endif #endif +// workaround for ROCm 6.2 and later #ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE -#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133 +#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \ + (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \ + (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3) #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1 #else #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0 diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 131729992b..8a13c0b060 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), - Policy::template MakeKRegSliceBlockDescriptor()); + Policy::template MakeKRegBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( Policy::template MakeKRegBlockDescriptor()); @@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), - Policy::template MakeVRegSliceBlockDescriptor()); - - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); + Policy::template MakeVRegBlockDescriptor()); //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg @@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); @@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); block_sync_lds(); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS @@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); @@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); @@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; - static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 3156e4a356..d1b6e6f85b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), - Policy::template MakeKRegSliceBlockDescriptor()); + Policy::template MakeKRegBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( Policy::template MakeKRegBlockDescriptor()); @@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), - Policy::template MakeVRegSliceBlockDescriptor()); - - auto v_reg_tensor = make_static_distributed_tensor( - Policy::template MakeVRegBlockDescriptor()); + Policy::template MakeVRegBlockDescriptor()); //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg @@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); @@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP block_sync_lds(); - v_reg_tensor = load_tile(v_lds_read_window); + auto v_reg_tensor = load_tile(v_lds_read_window); //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS auto q_dram_window = @@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); @@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); @@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP index_t i_total_loops = 0; index_t seqlen_q_step = seqlen_q_start; - static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); + static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 0afad0446c..d353203e0e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using QDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType); constexpr index_t kMinVecLoad = 4 / sizeof(QDataType); @@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); constexpr index_t kMinVecLoad = 4 / sizeof(KDataType); @@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; @@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using OGradDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType); @@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; return total_pixels / GetAlignmentK(); @@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentV(); constexpr index_t K0 = kKPerBlock / K1; @@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); return MakeXLdsBlockDescriptor(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto k_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - - return k_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() { @@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kVPack = GetSmemKPackV(); return MakeXLdsBlockDescriptor(); } - template - CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor() - { - using BlockGemm = remove_cvref_t())>; - constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); - using WarpGemm = remove_cvref_t())>; - - constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); - constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); - - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; - - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; - - constexpr auto v_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); - - return v_block_dstr; - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() { @@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); @@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { // Hold full block data constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPack = GetSmemKPackOGrad(); @@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0; + static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2; static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; static constexpr index_t WarpGemmM = @@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy // Compute static constexpr index_t Gemm0MFMA = - kM0 * kN0 * kQKHeaddim / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm1MFMA = - kM0 * kN0 * kVHeaddim / - (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); - static constexpr index_t Gemm2MFMA = kN0 * kVHeaddim * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); + static constexpr index_t Gemm2MFMA = + kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm3MFMA = kN0 * kQKHeaddim * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); @@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); static constexpr index_t SGradT_LDS_READ_P1 = kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); - static constexpr index_t Q_LDS_READ = - kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ(); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t SGradT_LDS_READ_P2 = kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); static constexpr index_t OGrad_LDS_READ = - kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); // LDS Write diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index dc5983e4d1..436d964c37 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -23,6 +23,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp new file mode 100644 index 0000000000..7044a53140 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +// UniversalGemm Policy +template +struct UniversalGemmPipelineAgBgCrPolicy +{ + using LayoutA = remove_cvref_t; + using LayoutB = remove_cvref_t; + using LayoutC = remove_cvref_t; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr bool TransposeC = true; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + using ADataType = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + + if constexpr(std::is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple(K0 * number{}, number{}, K1), + make_tuple(K1, number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(K1)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(K0, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_ak0_kMLdsLayer_m_ak1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_m_k; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = Problem::kBlockSize / M0; + constexpr auto K0PerThreadWrite = K0 / KThreadWrite; + constexpr auto KThreadRead = 64 / WarpGemm::kM; + constexpr auto K0PerThreadRead = K0 / KThreadRead; + + constexpr auto kfold = + (K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=kN0 + constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128) + ? 1 + : ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0 + ? M0 + : 128 / (K1 * WarpGemm::kM * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + K1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_m_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + + if constexpr(std::is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple(K0 * number{}, number{}, K1), + make_tuple(K1, number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(K1)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(K0, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_bk0_kNLdsLayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_n_k; + } + else // RowMajor B + { + constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = Problem::kBlockSize / N0; + constexpr auto K0PerThreadWrite = K0 / KThreadWrite; + constexpr auto KThreadRead = 64 / WarpGemm::kN; + constexpr auto K0PerThreadRead = K0 / KThreadRead; + + constexpr auto kfold = + (K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=kN0 + constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128) + ? 1 + : ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0 + ? N0 + : 128 / (K1 * WarpGemm::kN * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + K1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(K1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + K1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_n_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) * + MakeBLdsBlockDescriptor().get_element_space_size(); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + index_t smem_size = 0; + smem_size += smem_size_a + smem_size_b; + + return smem_size; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using WarpGemm = WarpGemmMfmaDispatcher; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = WarpGemm::kK; + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockGemmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile