mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 05:28:37 +00:00
Merge branch 'ck_migraphx_integration' into codegen-enable-hiprtc
This commit is contained in:
49
Jenkinsfile
vendored
49
Jenkinsfile
vendored
@@ -735,11 +735,11 @@ def process_results(Map conf=[:]){
|
||||
|
||||
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
|
||||
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
|
||||
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true;RUN_CODEGEN_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
|
||||
0 13 * * * % BUILD_LEGACY_OS=true ''' : ""
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
|
||||
0 13 * * * % BUILD_LEGACY_OS=true''' : ""
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
@@ -806,6 +806,10 @@ pipeline {
|
||||
name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the grouped conv large cases tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CODEGEN_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run codegen tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CK_TILE_FMHA_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -926,7 +930,30 @@ pipeline {
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 test_grouped_convnd_fwd_large_cases_xdl && \
|
||||
./bin/test_grouped_convnd_fwd_large_cases_xdl"""
|
||||
}
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run Codegen Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run Codegen Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CODEGEN_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \
|
||||
make -j64 check"""
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
@@ -951,7 +978,7 @@ pipeline {
|
||||
make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \
|
||||
cd ../ &&
|
||||
example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
|
||||
}
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
@@ -970,7 +997,7 @@ pipeline {
|
||||
make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \
|
||||
cd ../ &&
|
||||
example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
|
||||
}
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
@@ -995,7 +1022,7 @@ pipeline {
|
||||
make -j64 tile_example_gemm_basic && \
|
||||
cd ../ &&
|
||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
|
||||
}
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
@@ -1014,7 +1041,7 @@ pipeline {
|
||||
make -j64 tile_example_gemm_basic && \
|
||||
cd ../ &&
|
||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
|
||||
}
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
@@ -1040,7 +1067,7 @@ pipeline {
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " \
|
||||
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
|
||||
execute_args = " "
|
||||
}
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
|
||||
cleanWs()
|
||||
@@ -1059,7 +1086,7 @@ pipeline {
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " \
|
||||
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
|
||||
execute_args = " "
|
||||
}
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name)
|
||||
cleanWs()
|
||||
@@ -1140,7 +1167,7 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
}
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(composable_kernel_host)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
||||
@@ -5,56 +8,51 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
|
||||
add_compile_options(-std=c++17)
|
||||
find_package(hip)
|
||||
add_custom_target(codegen)
|
||||
find_package(ROCM)
|
||||
include(ROCMInstallTargets)
|
||||
include(ROCMTest)
|
||||
|
||||
# add include directories
|
||||
include_directories(BEFORE
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
${HIP_INCLUDE_DIRS}
|
||||
)
|
||||
rocm_setup_version(VERSION 1.0)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
|
||||
include(Embed)
|
||||
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
${CK_ROOT}/include/ck/*.hpp)
|
||||
#printouts fot debug purposes
|
||||
#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
|
||||
#message(STATUS "RELATIVE: ${CK_ROOT}/include")
|
||||
${CK_ROOT}/include/ck/*.hpp)
|
||||
# printouts fot debug purposes
|
||||
# message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
|
||||
# message(STATUS "RELATIVE: ${CK_ROOT}/include")
|
||||
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
|
||||
|
||||
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
add_compile_options(-std=c++17)
|
||||
|
||||
##message(STATUS "SOURCE_FILES: ${SOURCES}")
|
||||
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
# TODO: Use object library
|
||||
add_library(ck_host STATIC ${SOURCES})
|
||||
target_link_libraries(ck_host PRIVATE ck_headers)
|
||||
|
||||
set_target_properties(ck_host PROPERTIES
|
||||
LINKER_LANGUAGE CXX
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(ck_host PROPERTIES
|
||||
LINKER_LANGUAGE CXX
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
target_include_directories(ck_host PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
# target_include_directories(ck_host PUBLIC
|
||||
# $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
# )
|
||||
|
||||
add_executable(ck-template-driver driver/main.cpp)
|
||||
target_link_libraries(ck-template-driver ck_host)
|
||||
|
||||
rocm_install(
|
||||
rocm_install_targets(
|
||||
TARGETS ck_host ck_headers
|
||||
EXPORT ck_hostTargets
|
||||
EXPORT ck_host_targets
|
||||
INCLUDE include
|
||||
PRIVATE
|
||||
)
|
||||
rocm_export_targets(
|
||||
EXPORT ck_host_targets
|
||||
NAMESPACE composable_kernel::
|
||||
)
|
||||
rocm_install(EXPORT ck_hostTargets
|
||||
FILE composable_kernelck_hostTargets.cmake
|
||||
NAMESPACE composable_kernel::
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
|
||||
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
if(BUILD_TESTING)
|
||||
add_subdirectory(test)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,23 +1,25 @@
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
|
||||
add_subdirectory(rtc)
|
||||
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
|
||||
# do not build the tests when we build the library for various targets
|
||||
if(NOT GPU_ARCHS)
|
||||
foreach(TEST_SRC ${TEST_SRCS})
|
||||
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
|
||||
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
|
||||
add_executable(codegen_test_${BASE_NAME} ${TEST_SRC})
|
||||
if(CK_USE_ALTERNATIVE_PYTHON)
|
||||
target_link_options(codegen_test_${BASE_NAME} PRIVATE -lstdc++fs)
|
||||
endif()
|
||||
add_dependencies(codegen codegen_test_${BASE_NAME})
|
||||
add_dependencies(tests codegen_test_${BASE_NAME})
|
||||
add_dependencies(check codegen_test_${BASE_NAME})
|
||||
add_test(NAME codegen_test_${BASE_NAME} COMMAND codegen_test_${BASE_NAME})
|
||||
message("adding test codegen_test_${BASE_NAME}")
|
||||
target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host)
|
||||
target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/codegen/test/include)
|
||||
|
||||
# TODO: These tests need to be refactored to remove dependency on main ck
|
||||
# headers and device compilation.
|
||||
set(TESTS_REQUIRE_DEVICE_COMPILE
|
||||
grouped_conv_fwd_multiple_d_v1
|
||||
grouped_conv_fwd_multiple_d_v2
|
||||
grouped_conv_fwd_multiple_d_v3
|
||||
grouped_conv_fwd_multiple_d_v4
|
||||
)
|
||||
find_package(hip)
|
||||
|
||||
foreach(TEST_SRC ${TEST_SRCS})
|
||||
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
|
||||
rocm_add_test_executable(codegen_test_${BASE_NAME} ${TEST_SRC})
|
||||
target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host)
|
||||
target_include_directories(codegen_test_${BASE_NAME} PUBLIC include)
|
||||
if(BASE_NAME IN_LIST TESTS_REQUIRE_DEVICE_COMPILE)
|
||||
target_link_libraries(codegen_test_${BASE_NAME} hip::device)
|
||||
target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/include)
|
||||
target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include)
|
||||
endforeach()
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -1,36 +1,26 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <test.hpp>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <unordered_set>
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
const char* const ck_content_wrapper = R"__ck__(
|
||||
${content}
|
||||
)__ck__";
|
||||
|
||||
template <class P>
|
||||
inline std::string content_wrapper(P p)
|
||||
{
|
||||
return ck::host::InterpolateString(ck_content_wrapper,
|
||||
{{"content", std::string{p.data(), p.size()}}});
|
||||
}
|
||||
|
||||
inline std::vector<rtc::src_file> create_headers_for_test()
|
||||
{
|
||||
auto ck_headers = ck::host::GetHeaders();
|
||||
std::vector<rtc::src_file> result;
|
||||
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) {
|
||||
return rtc::src_file{p.first, content_wrapper(p.second)};
|
||||
std::string content;
|
||||
content.reserve(p.second.size() + 1);
|
||||
content.push_back(' '); // We need a whitespace before the content for hipRTC to work
|
||||
content.append(p.second.data(), p.second.size());
|
||||
return rtc::src_file{p.first, std::move(content)};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
find_package(hip)
|
||||
file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
add_library(ck_rtc ${RTC_SOURCES})
|
||||
target_include_directories(ck_rtc PUBLIC include)
|
||||
target_link_libraries(ck_rtc PUBLIC hip::host)
|
||||
target_link_libraries(ck_rtc PUBLIC -lstdc++fs)
|
||||
|
||||
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
|
||||
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
|
||||
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
|
||||
|
||||
#include <ck/filesystem.hpp>
|
||||
#include <rtc/kernel.hpp>
|
||||
#include <functional>
|
||||
#include <rtc/filesystem.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace rtc {
|
||||
@@ -11,7 +10,7 @@ namespace rtc {
|
||||
struct src_file
|
||||
{
|
||||
src_file(std::filesystem::path p, std::string c) : path{std::move(p)}, content{std::move(c)} {}
|
||||
CK::fs::path path;
|
||||
fs::path path;
|
||||
std::string content;
|
||||
};
|
||||
|
||||
|
||||
60
codegen/test/rtc/include/rtc/filesystem.hpp
Normal file
60
codegen/test/rtc/include/rtc/filesystem.hpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef GUARD_TEST_HOST_RTC_FILESYSTEM_HPP
|
||||
#define GUARD_TEST_HOST_RTC_FILESYSTEM_HPP
|
||||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
||||
// clang-format off
|
||||
#if defined(CPPCHECK)
|
||||
#define RTC_HAS_FILESYSTEM 1
|
||||
#define RTC_HAS_FILESYSTEM_TS 1
|
||||
#elif defined(_WIN32)
|
||||
#if _MSC_VER >= 1920
|
||||
#define RTC_HAS_FILESYSTEM 1
|
||||
#define RTC_HAS_FILESYSTEM_TS 0
|
||||
#elif _MSC_VER >= 1900
|
||||
#define RTC_HAS_FILESYSTEM 0
|
||||
#define RTC_HAS_FILESYSTEM_TS 1
|
||||
#else
|
||||
#define RTC_HAS_FILESYSTEM 0
|
||||
#define RTC_HAS_FILESYSTEM_TS 0
|
||||
#endif
|
||||
#elif defined(__has_include)
|
||||
#if __has_include(<filesystem>) && __cplusplus >= 201703L
|
||||
#define RTC_HAS_FILESYSTEM 1
|
||||
#else
|
||||
#define RTC_HAS_FILESYSTEM 0
|
||||
#endif
|
||||
#if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L
|
||||
#define RTC_HAS_FILESYSTEM_TS 1
|
||||
#else
|
||||
#define RTC_HAS_FILESYSTEM_TS 0
|
||||
#endif
|
||||
#else
|
||||
#define RTC_HAS_FILESYSTEM 0
|
||||
#define RTC_HAS_FILESYSTEM_TS 0
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
#if RTC_HAS_FILESYSTEM
|
||||
#include <filesystem>
|
||||
#elif RTC_HAS_FILESYSTEM_TS
|
||||
#include <experimental/filesystem>
|
||||
#else
|
||||
#error "No filesystem include available"
|
||||
#endif
|
||||
|
||||
namespace rtc {
|
||||
|
||||
#if RTC_HAS_FILESYSTEM
|
||||
namespace fs = ::std::filesystem;
|
||||
#elif RTC_HAS_FILESYSTEM_TS
|
||||
namespace fs = ::std::experimental::filesystem;
|
||||
#endif
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
#endif // GUARD_RTC_FILESYSTEM_HPP_
|
||||
@@ -2,13 +2,13 @@
|
||||
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
|
||||
|
||||
#include <string>
|
||||
#include <ck/filesystem.hpp>
|
||||
#include <rtc/filesystem.hpp>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
struct tmp_dir
|
||||
{
|
||||
CK::fs::path path;
|
||||
fs::path path;
|
||||
tmp_dir(const std::string& prefix = "");
|
||||
|
||||
void execute(const std::string& cmd) const;
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#ifdef HIPRTC_FOR_CODEGEN_TESTS
|
||||
#include <hip/hiprtc.h>
|
||||
#include <rtc/manage_ptr.hpp>
|
||||
#endif
|
||||
#include <rtc/tmp_dir.hpp>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <deque>
|
||||
#include <fstream>
|
||||
@@ -96,9 +97,9 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
|
||||
|
||||
for(const auto& src : srcs)
|
||||
{
|
||||
CK::fs::path full_path = td.path / src.path;
|
||||
CK::fs::path parent_path = full_path.parent_path();
|
||||
CK::fs::create_directories(parent_path);
|
||||
fs::path full_path = td.path / src.path;
|
||||
fs::path parent_path = full_path.parent_path();
|
||||
fs::create_directories(parent_path);
|
||||
write_string(full_path.string(), src.content);
|
||||
if(src.path.extension().string() == ".cpp")
|
||||
{
|
||||
@@ -112,7 +113,7 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
|
||||
td.execute(compiler() + options.flags);
|
||||
|
||||
auto out_path = td.path / out;
|
||||
if(not CK::fs::exists(out_path))
|
||||
if(not fs::exists(out_path))
|
||||
throw std::runtime_error("Output file missing: " + out);
|
||||
|
||||
auto obj = read_buffer(out_path.string());
|
||||
@@ -204,7 +205,7 @@ struct hiprtc_program
|
||||
}
|
||||
else
|
||||
{
|
||||
headers.push_back(std::string(src.content.begin(), src.content.end()));
|
||||
headers.push_back(std::move(src.content));
|
||||
include_names.push_back(std::move(src.path));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,10 +31,10 @@ std::string unique_string(const std::string& prefix)
|
||||
}
|
||||
|
||||
tmp_dir::tmp_dir(const std::string& prefix)
|
||||
: path(CK::fs::temp_directory_path() /
|
||||
: path(fs::temp_directory_path() /
|
||||
unique_string(prefix.empty() ? "ck-rtc" : "ck-rtc-" + prefix))
|
||||
{
|
||||
CK::fs::create_directories(this->path);
|
||||
fs::create_directories(this->path);
|
||||
}
|
||||
|
||||
void tmp_dir::execute(const std::string& cmd) const
|
||||
@@ -43,6 +43,6 @@ void tmp_dir::execute(const std::string& cmd) const
|
||||
std::system(s.c_str());
|
||||
}
|
||||
|
||||
tmp_dir::~tmp_dir() { CK::fs::remove_all(this->path); }
|
||||
tmp_dir::~tmp_dir() { fs::remove_all(this->path); }
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
@@ -282,7 +281,11 @@ int main(int argc, char* argv[])
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using CodegenGemmPolicy = ck_tile::
|
||||
UniversalGemmPipelineAgBgCrPolicy<matrix_a_layout, matrix_b_layout, matrix_c_layout>;
|
||||
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
|
||||
|
||||
invoke_gemm<ck_tile::half_t,
|
||||
matrix_a_layout,
|
||||
|
||||
@@ -615,96 +615,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr bool
|
||||
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row>)
|
||||
{
|
||||
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col>)
|
||||
{
|
||||
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B
|
||||
if constexpr(is_same_v<BLayout, Row>)
|
||||
{
|
||||
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Col>)
|
||||
{
|
||||
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B1
|
||||
if constexpr(is_same_v<B1Layout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<B1Layout, Col>)
|
||||
{
|
||||
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of C
|
||||
if constexpr(is_same_v<CLayout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<CLayout, Col>)
|
||||
{
|
||||
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
@@ -861,268 +771,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return str.str();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
struct Descriptor
|
||||
{
|
||||
template <class AGridDescriptor>
|
||||
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
|
||||
{
|
||||
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class BGridDescriptor>
|
||||
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
|
||||
{
|
||||
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class B1GridDescriptor>
|
||||
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
|
||||
{
|
||||
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
|
||||
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class CGridDescriptor>
|
||||
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
|
||||
{
|
||||
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
B1GridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
B1K1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
BBlockLdsExtraN,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
C0MatrixMask c0_matrix_mask;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
bool has_main_k_block_loop = true;
|
||||
bool is_valid = false;
|
||||
|
||||
constexpr Descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
|
||||
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
|
||||
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
|
||||
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
|
||||
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
|
||||
c0_matrix_mask{c.GetLength(I1)},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_m_n,
|
||||
block_2_ctile_map) and
|
||||
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
|
||||
b_grid_desc_bk0_n_bk1.GetLength(I1),
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) *
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I2),
|
||||
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
|
||||
{
|
||||
}
|
||||
|
||||
constexpr bool IsValid() const { return is_valid; }
|
||||
};
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
static constexpr auto
|
||||
make_descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
|
||||
CElementwiseOperation c_element_op = CElementwiseOperation{})
|
||||
{
|
||||
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
|
||||
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
|
||||
}
|
||||
|
||||
template <class Desc>
|
||||
__device__ static void Run(const Desc& desc,
|
||||
const float scale,
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const ADataType* __restrict__ p_b_grid,
|
||||
const ADataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid)
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
assert(desc.is_valid);
|
||||
#endif
|
||||
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
AccElementwiseOperation acc_element_op{scale};
|
||||
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/enable_if.hpp"
|
||||
#include "ck/utility/statically_indexed_array.hpp"
|
||||
|
||||
#ifdef __HIPCC_RTC__
|
||||
@@ -204,7 +205,7 @@ struct scalar_type<bool>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 1, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
@@ -240,7 +241,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
|
||||
|
||||
__device__ int static err = 0;
|
||||
template <typename T>
|
||||
struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -300,7 +301,7 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -370,7 +371,7 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -452,7 +453,7 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 16, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -546,7 +547,7 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 32, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -650,7 +651,7 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 64, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -766,7 +767,7 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 128, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -892,7 +893,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
|
||||
struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -1042,7 +1043,7 @@ struct non_native_vector_base
|
||||
|
||||
// non-native vector_type implementation
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
@@ -1077,7 +1078,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 2, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1137,7 +1138,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 4, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1207,7 +1208,7 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 8, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1289,7 +1290,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 16, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1383,7 +1384,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 32, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
@@ -1487,7 +1488,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
struct vector_type<T, 64, typename ck::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
@@ -157,8 +157,11 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// workaround for ROCm 6.2 and later
|
||||
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
|
||||
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
|
||||
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
|
||||
#else
|
||||
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
|
||||
|
||||
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
block_sync_lds();
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
|
||||
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
auto q_dram_window =
|
||||
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
|
||||
|
||||
@@ -215,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
|
||||
|
||||
@@ -234,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
|
||||
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -254,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
|
||||
|
||||
@@ -315,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -327,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
return total_pixels / GetAlignmentK<Problem>();
|
||||
@@ -338,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -376,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -399,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -422,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -445,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -816,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
|
||||
{
|
||||
@@ -865,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -890,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kVPack = GetSmemKPackV<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
|
||||
{
|
||||
@@ -940,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -966,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1048,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
@@ -1092,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1255,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
// Hold full block data
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
|
||||
@@ -1299,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1859,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
|
||||
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
|
||||
|
||||
static constexpr index_t WarpGemmM =
|
||||
@@ -1873,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
// Compute
|
||||
static constexpr index_t Gemm0MFMA =
|
||||
kM0 * kN0 * kQKHeaddim /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm1MFMA =
|
||||
kM0 * kN0 * kVHeaddim /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kN0 * kVHeaddim * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm3MFMA =
|
||||
kN0 * kQKHeaddim * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
@@ -1903,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
static constexpr index_t SGradT_LDS_READ_P1 =
|
||||
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t Q_LDS_READ =
|
||||
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
static constexpr index_t SGradT_LDS_READ_P2 =
|
||||
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t OGrad_LDS_READ =
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
|
||||
// LDS Write
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// UniversalGemm Policy
|
||||
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
|
||||
struct UniversalGemmPipelineAgBgCrPolicy
|
||||
{
|
||||
using LayoutA = remove_cvref_t<LayoutA_>;
|
||||
using LayoutB = remove_cvref_t<LayoutB_>;
|
||||
using LayoutC = remove_cvref_t<LayoutC_>;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr bool TransposeC = true;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
|
||||
if constexpr(std::is_same<tensor_layout::gemm::RowMajor, LayoutA>::value)
|
||||
{
|
||||
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(ADataType);
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K0 * number<MLdsLayer>{}, number<MPerBlock / MLdsLayer>{}, K1),
|
||||
make_tuple(K1, number<KPerBlock * MLdsLayer>{}, I1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
|
||||
number<K0 * MLdsLayer>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, number<MLdsLayer>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
a_lds_block_desc_ak0_kMLdsLayer_m_ak1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return a_lds_block_desc_m_k;
|
||||
}
|
||||
else // ColumnMajor A
|
||||
{
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
||||
// for compiler.
|
||||
constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
|
||||
constexpr auto KThreadWrite = Problem::kBlockSize / M0;
|
||||
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / WarpGemm::kM;
|
||||
constexpr auto K0PerThreadRead = K0 / KThreadRead;
|
||||
|
||||
constexpr auto kfold =
|
||||
(K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=mpair<=kN0
|
||||
constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128)
|
||||
? 1
|
||||
: ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0
|
||||
? M0
|
||||
: 128 / (K1 * WarpGemm::kM * sizeof(ADataType)));
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * M1>{},
|
||||
number<kfold * M0 / mpair>{},
|
||||
number<mpair>{},
|
||||
K1));
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * M1>{}, number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
|
||||
a_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return a_lds_block_desc_m_k;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
|
||||
if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, LayoutB>::value)
|
||||
{
|
||||
// NLdsLayer * K0 as logical Bank
|
||||
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
|
||||
? 1
|
||||
: 32 * 4 / KPerBlock / sizeof(BDataType);
|
||||
;
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, K1),
|
||||
make_tuple(K1, number<KPerBlock * NLdsLayer>{}, I1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
|
||||
number<K0 * NLdsLayer>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_kNLdsLayer_n_bk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_n_k;
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
constexpr auto KThreadWrite = Problem::kBlockSize / N0;
|
||||
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = 64 / WarpGemm::kN;
|
||||
constexpr auto K0PerThreadRead = K0 / KThreadRead;
|
||||
|
||||
constexpr auto kfold =
|
||||
(K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=npair<=kN0
|
||||
constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128)
|
||||
? 1
|
||||
: ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0
|
||||
? N0
|
||||
: 128 / (K1 * WarpGemm::kN * sizeof(BDataType)));
|
||||
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{},
|
||||
number<npair>{},
|
||||
K1));
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
|
||||
b_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
K1)),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_n_k;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
index_t smem_size = 0;
|
||||
smem_size += smem_size_a + smem_size_b;
|
||||
|
||||
return smem_size;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
Problem::BlockGemmShape::WarpTile::at(I0),
|
||||
Problem::BlockGemmShape::WarpTile::at(I1),
|
||||
Problem::BlockGemmShape::WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = WarpGemm::kK;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = BlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user