From 871810b885d211c6bd1c8c898334dbcb9f0196a0 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Tue, 5 Mar 2024 19:08:43 -0600 Subject: [PATCH] Add host lib (#1134) * Format * Format * Format * Remove const * Use the right template * Format * Format * add row/col instances * Add missing file * fixed * Format * Updates * Format * fixed rrr layout * Format * Update test and embed modules * Restore older version * Update year * Set -fPIC * Format * Use double for isnan * rename host folder to codegen + minor fix * add codegen CI test * add option to build components without building CK * fix the groovy syntax * fix typo * use the correct function for the codegen stage --------- Co-authored-by: Jing Zhang Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin [ROCm/composable_kernel commit: 8eff4d62b669df7c34e1490d38520537c0178e2f] --- Jenkinsfile | 51 +- cmake/Embed.cmake | 238 +++++ codegen/CMakeLists.txt | 49 + codegen/driver/main.cpp | 71 ++ .../ck/host/device_gemm_multiple_d.hpp | 42 + .../host/device_gemm_multiple_d/operation.hpp | 42 + .../host/device_gemm_multiple_d/problem.hpp | 39 + codegen/include/ck/host/headers.hpp | 18 + codegen/include/ck/host/operation/gemm.hpp | 49 + codegen/include/ck/host/stringutils.hpp | 104 +++ codegen/include/ck/host/types.hpp | 78 ++ codegen/include/ck/host/utils.hpp | 17 + codegen/src/device_gemm_multiple_d.cpp | 33 + ...gemm_multiple_d_operation_xdl_cshuffle.cpp | 295 ++++++ codegen/src/headers.cpp | 17 + codegen/src/types.cpp | 63 ++ codegen/src/utils.cpp | 21 + codegen/test/CMakeLists.txt | 11 + codegen/test/gemm_multiple_d.cpp | 185 ++++ codegen/test/include/test.hpp | 848 ++++++++++++++++++ codegen/test/rtc/CMakeLists.txt | 6 + .../test/rtc/include/rtc/compile_kernel.hpp | 27 + codegen/test/rtc/include/rtc/hip.hpp | 78 ++ codegen/test/rtc/include/rtc/kernel.hpp | 62 ++ codegen/test/rtc/include/rtc/manage_ptr.hpp | 55 ++ codegen/test/rtc/include/rtc/tmp_dir.hpp | 24 + 
codegen/test/rtc/src/compile_kernel.cpp | 95 ++ codegen/test/rtc/src/hip.cpp | 102 +++ codegen/test/rtc/src/kernel.cpp | 121 +++ codegen/test/rtc/src/tmp_dir.cpp | 48 + .../device_gemm_multiple_d_xdl_cshuffle.hpp | 335 +++++-- .../gpu/grid/block_to_ctile_map.hpp | 53 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 10 +- 33 files changed, 3170 insertions(+), 117 deletions(-) create mode 100644 cmake/Embed.cmake create mode 100644 codegen/CMakeLists.txt create mode 100644 codegen/driver/main.cpp create mode 100644 codegen/include/ck/host/device_gemm_multiple_d.hpp create mode 100644 codegen/include/ck/host/device_gemm_multiple_d/operation.hpp create mode 100644 codegen/include/ck/host/device_gemm_multiple_d/problem.hpp create mode 100644 codegen/include/ck/host/headers.hpp create mode 100644 codegen/include/ck/host/operation/gemm.hpp create mode 100644 codegen/include/ck/host/stringutils.hpp create mode 100644 codegen/include/ck/host/types.hpp create mode 100644 codegen/include/ck/host/utils.hpp create mode 100644 codegen/src/device_gemm_multiple_d.cpp create mode 100644 codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp create mode 100644 codegen/src/headers.cpp create mode 100644 codegen/src/types.cpp create mode 100644 codegen/src/utils.cpp create mode 100644 codegen/test/CMakeLists.txt create mode 100644 codegen/test/gemm_multiple_d.cpp create mode 100644 codegen/test/include/test.hpp create mode 100644 codegen/test/rtc/CMakeLists.txt create mode 100644 codegen/test/rtc/include/rtc/compile_kernel.hpp create mode 100644 codegen/test/rtc/include/rtc/hip.hpp create mode 100644 codegen/test/rtc/include/rtc/kernel.hpp create mode 100644 codegen/test/rtc/include/rtc/manage_ptr.hpp create mode 100644 codegen/test/rtc/include/rtc/tmp_dir.hpp create mode 100644 codegen/test/rtc/src/compile_kernel.cpp create mode 100644 codegen/test/rtc/src/hip.cpp create mode 100644 codegen/test/rtc/src/kernel.cpp create mode 100644 codegen/test/rtc/src/tmp_dir.cpp diff 
--git a/Jenkinsfile b/Jenkinsfile index 3cac20fd34..abecb76408 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -264,18 +264,24 @@ def cmake_build(Map conf=[:]){ """) sh cmd3 } - - def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") // reduce parallelism when compiling, clang uses too much memory def nt = nthreads() - def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") + def cmd def execute_cmd = conf.get("execute_cmd", "") - - def cmd = conf.get("cmd", """ + if(!setup_args.contains("NO_CK_BUILD")){ + def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") + def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} ${execute_cmd} """) + } + else{ + cmd = conf.get("cmd", """ + ${execute_cmd} + """) + } echo cmd @@ -667,7 +673,7 @@ pipeline { string( name: 'USE_CUSTOM_DOCKER', defaultValue: '', - description: 'If you want to use a custom docker image, please scecify it here (default: OFF).') + description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', defaultValue: '6.0', @@ -712,6 +718,10 @@ pipeline { name: "RUN_PERFORMANCE_TESTS", defaultValue: false, description: "Run the performance tests (default: OFF)") + booleanParam( + name: "RUN_CODEGEN_TESTS", + defaultValue: true, + description: "Run the codegen tests (default: ON)") } environment{ dbuser = "${dbuser}" @@ -790,7 +800,34 @@ pipeline { } } } - + stage("Run Codegen Tests") + { + parallel + { + stage("Run Codegen Tests on MI100/MI200") + { + when { + beforeAgent true + expression { params.RUN_CODEGEN_TESTS.toBoolean() } + } + options { retry(2) } + agent{ label rocmnode("gfx908 || gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \ + cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ 
+ -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx908;gfx90a" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Build CK and run Tests") { parallel diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake new file mode 100644 index 0000000000..4bc638b446 --- /dev/null +++ b/cmake/Embed.cmake @@ -0,0 +1,238 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+##################################################################################### + +if(WIN32) + set(EMBED_USE RC CACHE STRING "Use RC or CArrays to embed data files") + set_property(CACHE EMBED_USE PROPERTY STRINGS "RC;CArrays") +else() + if(BUILD_SHARED_LIBS) + set(EMBED_USE LD CACHE STRING "Use LD or CArrays to embed data files") + else() + set(EMBED_USE CArrays CACHE STRING "Use LD or CArrays to embed data files") + endif() + set_property(CACHE EMBED_USE PROPERTY STRINGS "LD;CArrays") +endif() + +if(EMBED_USE STREQUAL "LD") + find_program(EMBED_LD ld REQUIRED) + find_program(EMBED_OBJCOPY objcopy REQUIRED) +endif() + +function(embed_wrap_string) + set(options) + set(oneValueArgs VARIABLE AT_COLUMN) + set(multiValueArgs) + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + string(LENGTH ${${PARSE_VARIABLE}} string_length) + math(EXPR offset "0") + + while(string_length GREATER 0) + + if(string_length GREATER ${PARSE_AT_COLUMN}) + math(EXPR length "${PARSE_AT_COLUMN}") + else() + math(EXPR length "${string_length}") + endif() + + string(SUBSTRING ${${PARSE_VARIABLE}} ${offset} ${length} line) + set(lines "${lines}\n${line}") + + math(EXPR string_length "${string_length} - ${length}") + math(EXPR offset "${offset} + ${length}") + endwhile() + + set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE) +endfunction() + +function(generate_embed_source EMBED_NAME EMBED_DIR BASE_DIRECTORY) + set(options) + set(oneValueArgs) + set(multiValueArgs SYMBOLS FILES) + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(RESOURCE_ID 100) + + list(LENGTH PARSE_SYMBOLS SYMBOLS_LEN) + list(LENGTH PARSE_FILES FILES_LEN) + if(NOT ${SYMBOLS_LEN} EQUAL ${FILES_LEN}) + message(FATAL_ERROR "Symbols and objects dont match: ${SYMBOLS_LEN} != ${FILES_LEN}") + endif() + math(EXPR LEN "${SYMBOLS_LEN} - 1") + + foreach(idx RANGE ${LEN}) + list(GET PARSE_SYMBOLS ${idx} SYMBOL) + list(GET PARSE_FILES ${idx} FILE) + 
file(RELATIVE_PATH BASE_NAME "${BASE_DIRECTORY}" ${FILE}) + if(EMBED_USE STREQUAL "RC") + string(TOUPPER "${SYMBOL}" SYMBOL) + string(APPEND FILE_IDS "#define IDR_${SYMBOL} ${RESOURCE_ID}\n") + file(TO_NATIVE_PATH "${FILE}" NATIVE_FILE) + string(REPLACE "\\" "\\\\" NATIVE_FILE "${NATIVE_FILE}") + string(APPEND RC_FILE_MAPPING "IDR_${SYMBOL} TEXTFILE \"${NATIVE_FILE}\"\n") + string(APPEND INIT_KERNELS "\n {\"${BASE_NAME}\", resource::read(IDR_${SYMBOL})},") + math(EXPR RESOURCE_ID "${RESOURCE_ID} + 1" OUTPUT_FORMAT DECIMAL) + else() + set(START_SYMBOL "_binary_${SYMBOL}_start") + set(LENGTH_SYMBOL "_binary_${SYMBOL}_length") + if(EMBED_USE STREQUAL "LD") + string(APPEND EXTERNS " +extern const char ${START_SYMBOL}[]; +extern const size_t _binary_${SYMBOL}_size; +const auto ${LENGTH_SYMBOL} = reinterpret_cast(&_binary_${SYMBOL}_size); +") + else() + string(APPEND EXTERNS " +extern const char ${START_SYMBOL}[]; +extern const size_t ${LENGTH_SYMBOL}; +") + endif() + string(APPEND INIT_KERNELS " + { \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL}} },") + endif() + endforeach() + if(EMBED_USE STREQUAL "RC") + file(WRITE "${EMBED_DIR}/include/resource.h" " +#define TEXTFILE 256 + +${FILE_IDS} +") + file(WRITE "${EMBED_DIR}/resource.rc" " +#include \"resource.h\" + +${RC_FILE_MAPPING} +") + set(EXTERNS " +#include +#include \"resource.h\" + +namespace resource { +std::string_view read(int id) +{ + HMODULE handle = GetModuleHandle(nullptr); + HRSRC rc = FindResource(handle, MAKEINTRESOURCE(id), MAKEINTRESOURCE(TEXTFILE)); + HGLOBAL data = LoadResource(handle, rc); + return {static_cast(LockResource(data)), SizeofResource(handle, rc)}; +} +} +") + set(EMBED_FILES ${EMBED_DIR}/include/resource.h ${EMBED_DIR}/resource.rc) + endif() + file(WRITE "${EMBED_DIR}/include/${EMBED_NAME}.hpp" " +#include +#include +#include +std::unordered_map ${EMBED_NAME}(); +") + + file(WRITE "${EMBED_DIR}/${EMBED_NAME}.cpp" " +#include <${EMBED_NAME}.hpp> +${EXTERNS} +std::unordered_map 
${EMBED_NAME}() +{ + static std::unordered_map result = {${INIT_KERNELS} + }; + return result; +} +") + list(APPEND EMBED_FILES ${EMBED_DIR}/${EMBED_NAME}.cpp ${EMBED_DIR}/include/${EMBED_NAME}.hpp) + set(EMBED_FILES ${EMBED_FILES} PARENT_SCOPE) +endfunction() + +function(embed_file FILE BASE_DIRECTORY) + message(STATUS " ${FILE}") + file(RELATIVE_PATH REL_FILE "${BASE_DIRECTORY}" ${FILE}) + string(MAKE_C_IDENTIFIER "${REL_FILE}" OUTPUT_SYMBOL) + get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY) + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}") + if(EMBED_USE STREQUAL "LD") + set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o") + add_custom_command( + OUTPUT "${OUTPUT_FILE}" + COMMAND ${EMBED_LD} -r -o "${OUTPUT_FILE}" -z noexecstack --format=binary "${REL_FILE}" + COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUTPUT_FILE}" + WORKING_DIRECTORY "${BASE_DIRECTORY}" + DEPENDS "${FILE}" + VERBATIM) + set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE) + elseif(EMBED_USE STREQUAL "CArrays") + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${FILE}) + set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp") + # reads source file contents as hex string + file(READ ${FILE} HEX_STRING HEX) + # wraps the hex string into multiple lines + embed_wrap_string(VARIABLE HEX_STRING AT_COLUMN 80) + # adds '0x' prefix and comma suffix before and after every byte respectively + string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1, " ARRAY_VALUES ${HEX_STRING}) + # removes trailing comma + string(REGEX REPLACE ", $" "" ARRAY_VALUES ${ARRAY_VALUES}) + file(WRITE "${OUTPUT_FILE}" " +#include +extern const char _binary_${OUTPUT_SYMBOL}_start[] = { ${ARRAY_VALUES} }; +extern const size_t _binary_${OUTPUT_SYMBOL}_length = sizeof(_binary_${OUTPUT_SYMBOL}_start); +") + set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE) + endif() + set(OUTPUT_SYMBOL ${OUTPUT_SYMBOL} PARENT_SCOPE) +endfunction() + 
+function(add_embed_library EMBED_NAME) + set(options) + set(oneValueArgs RELATIVE) + set(multiValueArgs) + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME}) + file(MAKE_DIRECTORY ${EMBED_DIR}) + message(STATUS "Embedding kernel files:") + foreach(FILE ${PARSE_UNPARSED_ARGUMENTS}) + embed_file(${FILE} ${PARSE_RELATIVE}) + list(APPEND OUTPUT_FILES ${OUTPUT_FILE}) + list(APPEND SYMBOLS ${OUTPUT_SYMBOL}) + endforeach() + message(STATUS "Generating embedding library '${EMBED_NAME}'") + generate_embed_source(${EMBED_NAME} ${EMBED_DIR} "${PARSE_RELATIVE}" SYMBOLS ${SYMBOLS} FILES ${PARSE_UNPARSED_ARGUMENTS}) + set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME}) + if(EMBED_USE STREQUAL "LD") + add_library(${INTERNAL_EMBED_LIB} STATIC ${EMBED_FILES} ${OUTPUT_FILES}) + else() + add_library(${INTERNAL_EMBED_LIB} OBJECT ${EMBED_FILES}) + endif() + if(EMBED_USE STREQUAL "CArrays") + target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES}) + endif() + target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include") + target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations) + set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On) + add_library(${EMBED_NAME} INTERFACE) + if(EMBED_USE STREQUAL "RC") + target_link_libraries(${EMBED_NAME} INTERFACE $) + elseif(EMBED_USE STREQUAL "LD") + target_link_libraries(${EMBED_NAME} INTERFACE ${INTERNAL_EMBED_LIB}) + else() + target_sources(${EMBED_NAME} INTERFACE $) + endif() + target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include") +endfunction() + diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt new file mode 100644 index 0000000000..72549c9a4e --- /dev/null +++ b/codegen/CMakeLists.txt @@ -0,0 +1,49 @@ +cmake_minimum_required(VERSION 3.16) +project(composable_kernel_host) + 
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) + +find_package(ROCM) +include(ROCMInstallTargets) +include(ROCMTest) + +list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) +include(Embed) +file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS + ${CK_ROOT}/include/ck/*.hpp) +message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") +message(STATUS "RELATIVE: ${CK_ROOT}/include") +add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) + +add_definitions(-std=c++17) + +file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) +# TODO: Use object library +add_library(ck_host STATIC ${SOURCES}) +target_link_libraries(ck_host PRIVATE ck_headers) + +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) + +target_include_directories(ck_host PUBLIC + $ +) + +add_executable(ck-template-driver driver/main.cpp) +target_link_libraries(ck-template-driver ck_host) + +rocm_install( + TARGETS ck_host ck_headers + EXPORT ck_hostTargets +) +rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +if(BUILD_TESTING) +add_subdirectory(test) +endif() diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp new file mode 100644 index 0000000000..dfd513106b --- /dev/null +++ b/codegen/driver/main.cpp @@ -0,0 +1,71 @@ + +#include +#include +#include +#include +#include +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/stringutils.hpp" + +using ck::host::Transform; + +struct Emitters +{ + std::unordered_map()>> m; + + template + void Register(const std::string& name) + { + m[name] = [] { + auto configs = T::CreateOperations(); + + return Transform(configs, [](const auto& ops) { return ToTuple(ops); }); + }; + } + + template + static std::string ToTuple(const T& ops) + { + auto templates = 
Transform( + ops, [](const auto& op) { return " " + op.ToSolution().ToTemplateString(); }); + return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">"; + } + + std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); } + + std::vector List() const + { + return Transform(m, [](auto&& p) { return p.first; }); + } +}; + +int main(int argc, const char* argv[]) +{ + std::string prog = argv[0]; + std::vector args(argv + 1, argv + argc); + Emitters e; + e.Register( + "DeviceGemmMultipleD_Xdl_CShuffle"); + + if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) { + return arg == "-h" or arg == "--help"; + })) + { + std::cout << "USAGE:" << std::endl; + std::cout << " " << prog << " [TEMPLATE]" << std::endl; + std::cout << std::endl; + std::cout << "FLAGS:" << std::endl; + std::cout << " -h, --help Show help" << std::endl; + std::cout << std::endl; + std::cout << "TEMPLATES:" << std::endl; + for(auto x : e.List()) + std::cout << " " << x << std::endl; + std::cout << std::endl; + return 0; + } + + for(auto name : args) + std::cout << e.Emit(name) << std::endl; + + return 0; +} diff --git a/codegen/include/ck/host/device_gemm_multiple_d.hpp b/codegen/include/ck/host/device_gemm_multiple_d.hpp new file mode 100644 index 0000000000..88e040db53 --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::Tuple<>"; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp new file mode 100644 index 0000000000..f9d39633ac --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_gemm_multiple_d/problem.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Operation_Xdl_CShuffle +{ + static std::vector> CreateOperations(); + static std::vector CreateOperations(const Problem& prob); + TensorDesc A{}; + TensorDesc B{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::vector Ds = {}; + TensorDesc E{}; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string cde_elem_op = Bilinear; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + operation::TileDesc tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + Solution ToSolution() const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp new file mode 100644 index 0000000000..f6dbc2b6e8 --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string CDEElementOp = PassThrough; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp new file mode 100644 index 0000000000..3da05baaaf --- /dev/null +++ b/codegen/include/ck/host/headers.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +std::unordered_map GetHeaders(); + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp new file mode 100644 index 0000000000..f587122b05 --- /dev/null +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +namespace ck { +namespace host { +namespace operation { + +struct TileDesc +{ + int block_size = 0; + int m_per_block = 0; + int n_per_block = 0; + int k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int m_Xdl_per_wave = 0; + int n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; +struct BlockTransferDesc +{ + std::string thread_cluster_length = ""; + std::string thread_cluster_arrange_order = ""; + std::string src_access_order = ""; + int src_vec_dim = 0; + int src_scalar_per_vector = 0; + int dst_scalar_per_vector_k1 = 0; + int lds_add_extra_dim = 0; +}; +struct CShuffleDesc +{ + int m_Xdl_per_wave_per_shuffle = 0; + int n_Xdl_per_wave_per_shuffle = 0; +}; +struct CBlockTransferDesc +{ + std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = ""; + int scalar_per_vector_n_wave_n_per_Xdl = 0; +}; + +} // namespace operation +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp new file mode 100644 index 0000000000..01374b86c8 --- /dev/null +++ b/codegen/include/ck/host/stringutils.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +template +std::string trim(const std::string& s, F f) +{ + auto start = std::find_if_not(s.begin(), s.end(), f); + auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base(); + return {start, last}; +} + +inline std::string trim(const std::string& s) +{ + return trim(s, [](unsigned char c) { return std::isspace(c); }); +} + +template +inline std::string JoinStrings(Strings strings, const std::string& delim) +{ + auto it = strings.begin(); + if(it == strings.end()) + return ""; + + auto nit = std::next(it); + return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) { + return std::move(x) + delim + std::move(y); + }); +} + +template +inline std::string +InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}") +{ + std::string result = ""; + result.reserve(input.size()); + auto it = input.begin(); + while(it != input.end()) + { + auto next_start = std::search(it, input.end(), start.begin(), start.end()); + auto next_end = std::search(next_start, input.end(), end.begin(), end.end()); + result.append(it, next_start); + if(next_start == input.end()) + break; + if(next_end == input.end()) + { + throw std::runtime_error("Unbalanced brackets"); + } + auto r = f(next_start + start.size(), next_end); + result.append(r.begin(), r.end()); + it = next_end + end.size(); + } + return result; +} +inline std::string InterpolateString(const std::string& input, + const std::unordered_map& vars, + std::string start = "${", + std::string end = "}") +{ + return InterpolateString( + input, + [&](auto start_it, auto last_it) { + auto key = trim({start_it, last_it}); + auto it = vars.find(key); + if(it == vars.end()) + throw std::runtime_error("Unknown key: " + key); + return it->second; + }, + std::move(start), + std::move(end)); +} + +template +inline auto 
Transform(const Range& r, F f) -> std::vector +{ + std::vector result; + std::transform(r.begin(), r.end(), std::back_inserter(result), f); + return result; +} + +template +inline auto Transform(const Range1& r1, const Range2& r2, F f) + -> std::vector +{ + std::vector result; + assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); + std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f); + return result; +} + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp new file mode 100644 index 0000000000..23488a66d0 --- /dev/null +++ b/codegen/include/ck/host/types.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +struct Solution +{ + + Solution() = default; + Solution(std::string str, std::unordered_map values); + std::string ToTemplateString() const; + std::string GetTemplateParameter(const std::string& name) const; + template + T GetTemplateParameter(const std::string& name) const + { + T result; + std::stringstream ss(GetTemplateParameter(name)); + ss >> result; + return result; + } + + private: + std::string template_str; + std::unordered_map template_values; +}; + +enum class DataType +{ + Half, + Float, + Int8, + Int32 +}; + +std::string ToString(DataType dt); + +enum class Layout +{ + Row, + Column +}; + +std::string ToString(Layout dl); + +enum class GemmType +{ + Default +}; + +std::string ToString(GemmType gt); + +struct TensorDesc +{ + DataType element; + Layout layout; +}; + +std::string SequenceStr(const std::vector& v); + +std::string MakeTuple(const std::vector& v); + +template +const std::string S = SequenceStr({xs...}); + +constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; +constexpr const char* Bilinear = 
"ck::tensor_operation::element_wise::Bilinear"; + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/utils.hpp b/codegen/include/ck/host/utils.hpp new file mode 100644 index 0000000000..e8785a456f --- /dev/null +++ b/codegen/include/ck/host/utils.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +namespace ck { +namespace host { + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y); + +const std::unordered_set& get_xdlop_archs(); + +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d.cpp b/codegen/src/device_gemm_multiple_d.cpp new file mode 100644 index 0000000000..ec25afc0f9 --- /dev/null +++ b/codegen/src/device_gemm_multiple_d.cpp @@ -0,0 +1,33 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; +} + +std::vector Problem::GetSolutions(const std::string& arch) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(*this); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck \ No newline at end of file diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp new file mode 100644 index 
0000000000..9e397497ee --- /dev/null +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +static std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +static Layout ToLayout(bool Trans) { return Trans ? 
Layout::Column : Layout::Row; } + +std::vector Operation_Xdl_CShuffle::CreateOperations(const Problem& prob) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK| +// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| +// | | | | | | | | Wave| Wave| Stage| +// | | | | | | | | | | | + { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1}, + { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1}, + { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, + { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, + // clang-format on + }; + + std::vector a_block_descriptions_rowmajor = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + // clang-format on + }; + + std::vector a_block_descriptions_colmajor = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + // clang-format 
on + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, + }; + + std::vector b_block_descriptions_rowmajor = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + // clang-format on + }; + + std::vector b_block_descriptions_colmajor = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { 
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 4>, 8}, + { S<1, 16, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + const auto a_block_descriptions = + prob.TransA ? a_block_descriptions_colmajor : a_block_descriptions_rowmajor; + const auto b_block_descriptions = + prob.TransB ? b_block_descriptions_colmajor : b_block_descriptions_rowmajor; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + 
x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + result.push_back(x); + } + return result; +} + +std::vector> Operation_Xdl_CShuffle::CreateOperations() +{ + std::vector problems; + for(bool TransA : {true, false}) + for(bool TransB : {true, false}) + { + Problem prob; + prob.TransA = TransA; + prob.TransB = TransB; + problems.push_back(prob); + } + return Transform(problems, [](const Problem& p) { return CreateOperations(p); }); +} + +static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<${LayoutA}, ${LayoutB}, " + "${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, " + "${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, " + "${CDEElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, " + "${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, " + "${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, " + "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " + "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " + "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " + "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB", 
ToString(this->B.layout)}, + {"LayoutDs", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))}, + {"LayoutE", ToString(this->E.layout)}, + {"ADataType", ToString(this->A.element)}, + {"BDataType", ToString(this->B.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"DsDataType", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))}, + {"EDataType", ToString(this->E.element)}, + {"AElementwiseOperation", this->a_elem_op}, + {"BElementwiseOperation", this->b_elem_op}, + {"CDEElementwiseOperation", this->cde_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"MPerBlock", std::to_string(this->tile_desc.m_per_block)}, + {"NPerBlock", std::to_string(this->tile_desc.n_per_block)}, + {"KPerBlock", std::to_string(this->tile_desc.k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)}, + {"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", 
std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"BBlockTransferThreadClusterLengths_BK0_N_BK1", + this->b_block_transfer.thread_cluster_length}, + {"BBlockTransferThreadClusterArrangeOrder", + this->b_block_transfer.thread_cluster_arrange_order}, + {"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order}, + {"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)}, + {"BBlockTransferSrcScalarPerVector", + std::to_string(this->b_block_transfer.src_scalar_per_vector)}, + {"BBlockTransferDstScalarPerVector_BK1", + std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)}, + {"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CDEBlockTransferScalarPerVector_NPerBlock", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + }; + + return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp new file mode 100644 index 0000000000..6fcb94cdbd --- /dev/null +++ b/codegen/src/headers.cpp @@ -0,0 +1,17 @@ +#include "ck/host/headers.hpp" +#include "ck_headers.hpp" + +namespace ck { +namespace host { + +const std::string config_header = ""; + +std::unordered_map GetHeaders() +{ + auto headers = ck_headers(); + headers.insert(std::make_pair("ck/config.h", config_header)); + return headers; +} + +} // namespace host +} // namespace ck \ No newline at end of file diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp 
new file mode 100644 index 0000000000..d43df73f33 --- /dev/null +++ b/codegen/src/types.cpp @@ -0,0 +1,63 @@ +#include "ck/host/types.hpp" +#include "ck/host/stringutils.hpp" +#include +#include + +namespace ck { +namespace host { + +Solution::Solution(std::string str, std::unordered_map values) + : template_str(std::move(str)), template_values(std::move(values)) +{ +} + +std::string Solution::ToTemplateString() const { return this->template_str; } +std::string Solution::GetTemplateParameter(const std::string& name) const +{ + return this->template_values.at(name); +} + +std::string ToString(DataType dt) +{ + switch(dt) + { + case DataType::Float: return "float"; + case DataType::Half: return "ck::half_t"; + case DataType::Int8: return "int8_t"; + case DataType::Int32: return "int32_t"; + } + throw std::runtime_error("Incorrect data type"); +} + +std::string ToString(Layout dl) +{ + switch(dl) + { + case Layout::Row: return "ck::tensor_layout::gemm::RowMajor"; + case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor"; + } + throw std::runtime_error("Incorrect layout"); +} + +std::string ToString(GemmType gt) +{ + switch(gt) + { + case GemmType::Default: return "ck::tensor_operation::device::GemmSpecialization::Default"; + } + throw std::runtime_error("Incorrect gemm type"); +} + +std::string SequenceStr(const std::vector& v) +{ + return "ck::Sequence<" + + JoinStrings(Transform(v, [](int x) { return std::to_string(x); }), ", ") + ">"; +} + +std::string MakeTuple(const std::vector& v) +{ + return "ck::Tuple<" + JoinStrings(v, ", ") + ">"; +} + +} // namespace host +} // namespace ck diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp new file mode 100644 index 0000000000..cd6700c489 --- /dev/null +++ b/codegen/src/utils.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/host/utils.hpp" + +namespace ck { +namespace host { + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y) +{ + return (x + y - std::size_t{1}) / y; +} + +const std::unordered_set& get_xdlop_archs() +{ + static std::unordered_set supported_archs{"gfx90a", "gfx908", "gfx940", "gfx942"}; + return supported_archs; +} + +} // namespace host +} // namespace ck diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt new file mode 100644 index 0000000000..897cce1c94 --- /dev/null +++ b/codegen/test/CMakeLists.txt @@ -0,0 +1,11 @@ + +list(APPEND CMAKE_PREFIX_PATH /opt/rocm) +add_subdirectory(rtc) + +file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) +foreach(TEST_SRC ${TEST_SRCS}) +get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) +rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) +target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) +target_include_directories(test_host_${BASE_NAME} PUBLIC include()) +endforeach() diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp new file mode 100644 index 0000000000..17b659993a --- /dev/null +++ b/codegen/test/gemm_multiple_d.cpp @@ -0,0 +1,185 @@ +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include +#include +#include +#include +#include +#include +#include + +using half = _Float16; +// using half = __fp16; + +std::vector get_headers_for_test() +{ + std::vector result; + auto hs = ck::host::GetHeaders(); + std::transform( + hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { + return {p.first, p.second}; + }); + return result; +} + +template +rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) +{ + rtc::buffer result(n); + std::mt19937 gen(seed); + std::uniform_real_distribution dis(-1.0); + std::generate(result.begin(), result.end(), 
[&] { return dis(gen); }); + return result; +} + +template +bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) +{ + return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { + return fabs(x - y) < atol + rtol * fabs(y); + }); +} + +std::string classify(double x) +{ + switch(std::fpclassify(x)) + { + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; + } +} + +template +void print_classification(const Buffer& x) +{ + std::unordered_set result; + for(const auto& i : x) + result.insert(classify(i)); + for(const auto& c : result) + std::cout << c << ", "; + std::cout << std::endl; +} + +template +void print_statistics(const Buffer& x) +{ + std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; + std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; + double num_elements = x.size(); + auto mean = + std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; + auto stddev = std::sqrt( + std::accumulate(x.begin(), + x.end(), + double{0.0}, + [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / + num_elements); + std::cout << "Mean: " << mean << ", "; + std::cout << "StdDev: " << stddev << "\n"; +} + +template +void print_preview(const Buffer& x) +{ + if(x.size() <= 10) + { + std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); + } + else + { + std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); + std::cout << "..., "; + std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); + } + std::cout << std::endl; +} + +template +struct check_all +{ + rtc::buffer data{}; + bool operator()(const rtc::buffer& x) + { + if(data.empty()) + { + data = x; + return true; + } + if(std::any_of(x.begin(), x.end(), [](double y) { return 
std::isnan(y); })) + return false; + return allclose(data, x); + } +}; + +template +auto report(const Solution& solution, bool pass) +{ + return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); +} + +const std::string gemm_compile_check = R"__ck__( +#include <${include}> + +extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { + using G = ${template}; + constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})), + ck::make_tuple(), + ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n}))); + + static_assert(desc.IsValid(), "Invalid ck gemm."); + + if constexpr(desc.IsValid()) + { + ${template}::Run(desc, + a, + b, + ck::make_tuple(), + c); + } +} + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + ck::host::device_gemm_multiple_d::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + check_all check; + auto a = to_gpu(generate_buffer(1024 * 1024, 0)); + auto b = to_gpu(generate_buffer(1024 * 1024, 1)); + auto c = to_gpu(generate_buffer(1024 * 1024, 2)); + + for(auto solution : prob.GetSolutions("gfx90a")) + { + auto src = ck::host::InterpolateString(gemm_compile_check, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + options.kernel_name = "f"; + auto k = rtc::compile_kernel(srcs, options); + auto block_size = solution.GetTemplateParameter("BlockSize"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * + ck::host::integer_divide_ceil(prob.N, 
n_per_block); + k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); + CHECK(report(solution, check(rtc::from_gpu(c)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/include/test.hpp b/codegen/test/include/test.hpp new file mode 100644 index 0000000000..c3e38d6002 --- /dev/null +++ b/codegen/test/include/test.hpp @@ -0,0 +1,848 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __linux__ +#include +#endif + +#ifndef MIGRAPHX_GUARD_TEST_TEST_HPP +#define MIGRAPHX_GUARD_TEST_TEST_HPP + +namespace test { +// clang-format off +// NOLINTNEXTLINE +#define TEST_FOREACH_BINARY_OPERATORS(m) \ + m(==, equal) \ + m(!=, not_equal) \ + m(<=, less_than_equal) \ + m(>=, greater_than_equal) \ + m(<, less_than) \ + m(>, greater_than) \ + m(and, and_op) \ + m(or, or_op) +// clang-format on + +// clang-format off +// NOLINTNEXTLINE +#define TEST_FOREACH_UNARY_OPERATORS(m) \ + m(not, not_op) +// clang-format on + +// NOLINTNEXTLINE +#define TEST_EACH_BINARY_OPERATOR_OBJECT(op, name) \ + struct name \ + { \ + static std::string as_string() { return #op; } \ + template \ + static decltype(auto) call(T&& x, U&& y) \ + { \ + return x op y; \ + } \ + }; + +// NOLINTNEXTLINE +#define TEST_EACH_UNARY_OPERATOR_OBJECT(op, name) \ + struct name \ + { \ + static std::string as_string() { return #op; } \ + template \ + static decltype(auto) call(T&& x) \ + { \ + return op x; \ + } \ + }; + +TEST_FOREACH_BINARY_OPERATORS(TEST_EACH_BINARY_OPERATOR_OBJECT) +TEST_FOREACH_UNARY_OPERATORS(TEST_EACH_UNARY_OPERATOR_OBJECT) + +struct nop +{ + static std::string as_string() { return ""; } + template + static auto call(T&& x) + { + return static_cast(x); + } +}; + +struct function +{ + static std::string as_string() { return ""; } + template + static decltype(auto) call(T&& x) + { + return x(); + } +}; + +template +Stream& stream_range(Stream& s, Iterator start, Iterator last); + +template +inline Stream& operator<<(Stream& s, std::nullptr_t) +{ + s << "nullptr"; + return s; +} + +template {}>::type> +inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end())) +{ + s << "{ "; + stream_range(s, v.begin(), v.end()); + s << "}"; + return s; +} + +template +inline Stream& stream_range(Stream& s, 
Iterator start, Iterator last) +{ + if(start != last) + { + s << *start; + std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; }); + } + return s; +} + +template +const T& get_value(const T& x) +{ + return x; +} + +template +struct lhs_expression; + +template +lhs_expression make_lhs_expression(T&& lhs); + +template +lhs_expression make_lhs_expression(T&& lhs, Operator); + +// NOLINTNEXTLINE +#define TEST_EXPR_BINARY_OPERATOR(op, name) \ + template \ + auto operator op(const V& rhs2) const \ + { \ + return make_expression(*this, rhs2, name{}); /* NOLINT */ \ + } + +// NOLINTNEXTLINE +#define TEST_EXPR_UNARY_OPERATOR(op, name) \ + auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ } + +template +struct expression +{ + T lhs; + U rhs; + + friend std::ostream& operator<<(std::ostream& s, const expression& self) + { + s << self.lhs << " " << Operator::as_string() << " " << self.rhs; + return s; + } + + friend decltype(auto) get_value(const expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); }; + + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) +}; + +// TODO: Remove rvalue references +template +expression make_expression(T&& rhs, U&& lhs, Operator) +{ + return {std::forward(rhs), std::forward(lhs)}; +} + +// TODO: Remove rvalue reference +template +lhs_expression make_lhs_expression(T&& lhs) +{ + return lhs_expression{std::forward(lhs)}; +} + +template +lhs_expression make_lhs_expression(T&& lhs, Operator) +{ + return lhs_expression{std::forward(lhs)}; +} + +template +struct lhs_expression +{ + T lhs; + explicit lhs_expression(T e) : lhs(e) {} + + friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self) + { + std::string op = Operator::as_string(); + if(not op.empty()) + s << Operator::as_string() << " "; + s << self.lhs; + return s; + } + + friend decltype(auto) 
get_value(const lhs_expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs)); } + + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) + +// NOLINTNEXTLINE +#define TEST_LHS_REOPERATOR(op) \ + template \ + auto operator op(const U& rhs) const \ + { \ + return make_lhs_expression(lhs op rhs); \ + } + TEST_LHS_REOPERATOR(+) + TEST_LHS_REOPERATOR(-) + TEST_LHS_REOPERATOR(*) + TEST_LHS_REOPERATOR(/) + TEST_LHS_REOPERATOR(%) + TEST_LHS_REOPERATOR(&) + TEST_LHS_REOPERATOR(|) + TEST_LHS_REOPERATOR(^) +}; + +template +struct predicate +{ + std::string msg; + F f; + + friend std::ostream& operator<<(std::ostream& s, const predicate& self) + { + s << self.msg; + return s; + } + + decltype(auto) operator()() const { return f(); } + + operator decltype(auto)() const { return f(); } +}; + +template +auto make_predicate(const std::string& msg, F f) +{ + return make_lhs_expression(predicate{msg, f}, function{}); +} + +inline std::string as_string(bool x) +{ + if(x) + return "true"; + return "false"; +} + +template +std::string as_string(const T& x) +{ + std::stringstream ss; + ss << x; + return ss.str(); +} + +template +std::string as_string(Iterator start, Iterator last) +{ + std::stringstream ss; + stream_range(ss, start, last); + return ss.str(); +} + +template +auto make_function(const std::string& name, F f) +{ + return [=](auto&&... 
xs) { + std::vector args = {as_string(xs)...}; + return make_predicate(name + "(" + as_string(args.begin(), args.end()) + ")", + [=] { return f(xs...); }); + }; +} + +struct capture +{ + template + auto operator->*(const T& x) const + { + return make_lhs_expression(x); + } + + template + auto operator->*(const lhs_expression& x) const + { + return x; + } +}; + +enum class color +{ + reset = 0, + bold = 1, + underlined = 4, + fg_red = 31, + fg_green = 32, + fg_yellow = 33, + fg_blue = 34, + fg_default = 39, + bg_red = 41, + bg_green = 42, + bg_yellow = 43, + bg_blue = 44, + bg_default = 49 +}; +inline std::ostream& operator<<(std::ostream& os, const color& c) +{ +#ifndef _WIN32 + static const bool use_color = isatty(STDOUT_FILENO) != 0; + if(use_color) + return os << "\033[" << static_cast(c) << "m"; +#else + (void)c; +#endif + return os; +} + +inline std::atomic& failures() +{ + // NOLINTNEXTLINE + static std::atomic f = 0; + return f; +} + +template +void failed(T x, const char* msg, const char* func, const char* file, int line, F f) +{ + if(not bool(x.value())) + { + failures()++; + std::cout << func << std::endl; + std::cout << file << ":" << line << ":" << std::endl; + std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " " + << "[ " << x << " ]" << std::endl; + f(); + } +} + +template +bool throws(F f) +{ + try + { + f(); + return false; + } + catch(...) + { + return true; + } +} + +template +bool throws(F f, const std::string& msg = "") +{ + try + { + f(); + return false; + } + catch(const Exception& ex) + { + return std::string(ex.what()).find(msg) != std::string::npos; + } +} + +template +auto within_abs(T px, U py, double ptol = 1e-6f) +{ + return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })( + px, py, ptol); +} + +// This implements the basic globbing algorithm where `*` matches any number +// of characters(including none) and `?` matches any single character. 
It +// doesn't support character classes. +// +// This is a simple recursive implementation that scans the string where the +// string and pattern match. When a `*` is found in the pattern, the +// `glob_match` function is called recursively to compare the rest of the +// pattern to the rest of the string. If the recursive call returns true, +// then we have a match. However, if it returns false, then we advance one +// character and make the recursive call again. This is referred to as a +// star-loop, which will consume zero or more characters. +// +// This simple recursive implementation works well for short strings and +// patterns with few stars. First, it is unlikely to use many stars to glob +// test names. Secondly, using many stars is still significantly faster than +// using the equivalent std::regex, which has a much slower time complexity. +template <class Iterator1, class Iterator2> +bool glob_match(Iterator1 start, Iterator1 last, Iterator2 pattern_start, Iterator2 pattern_last) +{ + std::tie(start, pattern_start) = + std::mismatch(start, last, pattern_start, pattern_last, [](auto c, auto m) { + if(m == '?') + return true; + // We need a loop for star, so bail and handle the loop below + if(m == '*') + return false; + return c == m; + }); + // If there is no more pattern then return true if there is no more string to match + if(pattern_start == pattern_last) + return start == last; + // If the pattern is not a star then it's a mismatch + if(*pattern_start != '*') + return false; + // Multiple stars are the same as a single star so skip over multiple stars + pattern_start = std::find_if(pattern_start, pattern_last, [](auto c) { return c != '*'; }); + // If the star is at the end then return true + if(pattern_start == pattern_last) + return true; + // star-loop: match the rest of the pattern and text + while(not glob_match(start, last, pattern_start, pattern_last) and start != last) + start++; + // If the string is empty then it means a match was never found + return start != last; +} + 
+using string_map = std::unordered_map<std::string, std::vector<std::string>>; + +template <class Keyword> +string_map generic_parse(std::vector<std::string> as, Keyword keyword) +{ + string_map result; + + std::string flag; + for(auto&& x : as) + { + auto f = keyword(x); + if(f.empty()) + { + result[flag].push_back(x); + } + else + { + flag = f.front(); + result[flag]; // Ensure the flag exists + flag = f.back(); + } + } + return result; +} + +using test_case = std::function<void()>; + +inline auto& get_test_cases() +{ + // NOLINTNEXTLINE + static std::vector<std::pair<std::string, test_case>> cases; + return cases; +} + +inline void add_test_case(std::string name, test_case f) +{ + get_test_cases().emplace_back(std::move(name), std::move(f)); +} + +struct auto_register_test_case +{ + template <class F> + auto_register_test_case(const char* name, F f) noexcept + { + add_test_case(name, f); + } +}; + +struct failure_error +{ +}; + +[[noreturn]] inline void fail() { throw failure_error{}; } + +struct driver +{ + driver() + { + add_flag({"--help", "-h"}, "Show help"); + add_flag({"--list", "-l"}, "List all test cases"); + add_flag({"--continue", "-c"}, "Continue after failure"); + add_flag({"--quiet", "-q"}, "Don't print out extra output"); + } + struct argument + { + std::vector<std::string> flags = {}; + std::string help = ""; + int nargs = 1; + }; + + void add_arg(const std::vector<std::string>& flags, const std::string& help = "") + { + arguments.push_back(argument{flags, help, 1}); + } + + void add_flag(const std::vector<std::string>& flags, const std::string& help = "") + { + arguments.push_back(argument{flags, help, 0}); + } + + static void wrap(std::ostream& os, + const std::string& text, + const std::string& prefix = "", + unsigned int line_length = 80) + { + std::istringstream iss(text); + std::string line = prefix; + do + { + std::string word; + iss >> word; + if(line.length() + word.length() > line_length) + { + os << line << std::endl; + line = prefix; + } + line += word + " "; + } while(iss); + if(not line.empty()) + os << line << std::endl; + } + + void show_help(const std::string& exe) const + { + const 
std::string prefix = " "; + std::cout << std::endl; + std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl; + std::cout << " "; + std::cout << exe << " ... " << std::endl; + std::cout << std::endl; + + std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl; + std::cout << " "; + std::cout << color::fg_green << "..." << color::reset; + std::cout << std::endl; + + wrap(std::cout, + "Test cases to run. A test case can be either the exact test case name or a glob. A " + "glob expression uses a '*' to select zero or more characters or a '?' to select any " + "single character.", + prefix + prefix); + + std::cout << std::endl; + std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl; + for(auto&& arg : arguments) + { + std::cout << color::fg_green; + std::string arg_prefix = prefix; + for(const std::string& a : arg.flags) + { + std::cout << arg_prefix; + std::cout << a; + arg_prefix = ", "; + } + std::cout << color::reset << std::endl; + wrap(std::cout, arg.help, prefix + prefix); + } + } + + std::ostream& out() const + { + struct null_buffer : std::streambuf + { + virtual int overflow(int c) override { return c; } + }; + static null_buffer buffer; + static std::ostream null_stream(&buffer); + if(quiet) + return null_stream; + return std::cout; + } + + string_map parse(int argc, const char* argv[]) const + { + std::vector args(argv + 1, argv + argc); + string_map keys; + for(auto&& arg : arguments) + { + for(auto&& flag : arg.flags) + { + keys[flag] = {arg.flags.front()}; + if(arg.nargs == 0) + keys[flag].push_back(""); + } + } + auto result = generic_parse(args, [&](auto&& s) -> std::vector { + if(keys.count(s) > 0) + return keys[s]; + else + return {}; + }); + result["__exe__"].push_back(argv[0]); + return result; + } + + static std::string create_command(const string_map& args) + { + std::stringstream ss; + ss << args.at("__exe__").front(); + if(args.count("") > 0) + { + for(auto&& arg : args.at("")) + ss << " \"" << 
arg << "\""; + } + for(auto&& p : args) + { + if(p.first == "__exe__") + continue; + if(p.first.empty()) + continue; + ss << " " << p.first; + for(auto&& arg : p.second) + ss << " \"" << arg << "\""; + } + return ss.str(); + } + + static std::string fork(const std::string& name, string_map args) + { + std::string msg; + args[""] = {name}; + args.erase("--continue"); + args["--quiet"]; + auto cmd = create_command(args); + auto r = std::system(cmd.c_str()); // NOLINT + if(r != 0) + msg = "Exited with " + std::to_string(r); + return msg; + } + + static std::vector> glob_tests(const std::string& pattern) + { + std::vector> result; + std::copy_if(get_test_cases().begin(), + get_test_cases().end(), + std::back_inserter(result), + [&](auto&& p) { + return glob_match( + p.first.begin(), p.first.end(), pattern.begin(), pattern.end()); + }); + return result; + } + + void run_test_case(const std::string& name, const test_case& f, const string_map& args) + { + ran++; + out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name + << color::reset << std::endl; + std::string msg; + auto start = std::chrono::steady_clock::now(); + if(args.count("--continue") > 0) + { + msg = fork(name, args); + } + else + { + try + { + failures() = 0; + f(); + } + // cppcheck-suppress migraphx-EmptyCatchStatement + catch(const failure_error&) + { + } + } + auto finish = std::chrono::steady_clock::now(); + auto elapsed_ms = + std::chrono::duration_cast>(finish - start) + .count(); + if(msg.empty() and failures() != 0) + { + if(failures() == 1) + msg = "Test failure"; + else + msg = std::to_string(failures()) + " test failures"; + } + if(msg.empty()) + { + out() << color::fg_green << "[ COMPLETE ] " << color::reset; + } + else + { + failed.push_back(name); + out() << color::fg_red << "[ FAILED ] " << color::reset; + } + out() << color::bold << name << color::reset; + out() << color::fg_blue << " (" << elapsed_ms << "ms)" << color::reset; + if(not msg.empty()) + out() << ": " << 
color::fg_yellow << msg << color::reset; + out() << std::endl; + } + + void run(int argc, const char* argv[]) + { + auto args = parse(argc, argv); + if(args.count("--help") > 0) + { + show_help(args.at("__exe__").front()); + return; + } + if(args.count("--list") > 0) + { + for(auto&& tc : get_test_cases()) + out() << tc.first << std::endl; + return; + } + + if(args.count("--quiet") > 0) + quiet = true; + + auto cases = args[""]; + if(cases.empty()) + { + for(auto&& tc : get_test_cases()) + run_test_case(tc.first, tc.second, args); + } + else + { + std::unordered_map m(get_test_cases().begin(), + get_test_cases().end()); + + for(auto&& iname : cases) + { + std::vector> found_cases; + for(auto&& pattern : get_case_names(iname)) + { + auto f = m.find(pattern); + if(f == m.end()) + { + found_cases = glob_tests(pattern); + } + else + { + found_cases.push_back(*f); + } + } + if(found_cases.empty()) + { + out() << color::fg_red << "[ ERROR ] Test case '" << iname << "' not found." + << color::reset << std::endl; + failed.push_back(iname); + } + for(auto&& p : found_cases) + run_test_case(p.first, p.second, args); + } + } + out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran" + << color::reset << std::endl; + if(not failed.empty()) + { + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size() + << " tests failed" << color::reset << std::endl; + for(auto&& name : failed) + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name + << color::reset << std::endl; + std::exit(1); + } + } + + std::function(const std::string&)> get_case_names = + [](const std::string& name) -> std::vector { return {name}; }; + std::vector arguments = {}; + std::vector failed = {}; + std::size_t ran = 0; + bool quiet = false; +}; + +inline void run(int argc, const char* argv[]) +{ + driver d{}; + d.run(argc, argv); +} + +} // namespace test + +// NOLINTNEXTLINE +#define TEST_CAPTURE(...) 
test::capture{}->*__VA_ARGS__ + +// NOLINTNEXTLINE +#define CHECK(...) \ + test::failed( \ + TEST_CAPTURE(__VA_ARGS__), #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] {}) + +// NOLINTNEXTLINE +#define EXPECT(...) \ + test::failed(TEST_CAPTURE(__VA_ARGS__), \ + #__VA_ARGS__, \ + __PRETTY_FUNCTION__, \ + __FILE__, \ + __LINE__, \ + &test::fail) +// NOLINTNEXTLINE +#define STATUS(...) EXPECT((__VA_ARGS__) == 0) + +// NOLINTNEXTLINE +#define TEST_CAT(x, ...) TEST_PRIMITIVE_CAT(x, __VA_ARGS__) +// NOLINTNEXTLINE +#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__ + +// NOLINTNEXTLINE +#define TEST_CASE_REGISTER(...) \ + static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \ + test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__); + +// NOLINTNEXTLINE +#define TEST_CASE(...) \ + void __VA_ARGS__(); \ + TEST_CASE_REGISTER(__VA_ARGS__) \ + void __VA_ARGS__() + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wglobal-constructors" +#endif + +#endif diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt new file mode 100644 index 0000000000..441e60ca88 --- /dev/null +++ b/codegen/test/rtc/CMakeLists.txt @@ -0,0 +1,6 @@ + +find_package(hip) +file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) +add_library(ck_rtc ${RTC_SOURCES}) +target_include_directories(ck_rtc PUBLIC include) +target_link_libraries(ck_rtc PUBLIC hip::host) diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp new file mode 100644 index 0000000000..5a4a4b0dd6 --- /dev/null +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -0,0 +1,27 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL + +#include +#include +#include + +namespace rtc { + +struct src_file +{ + std::filesystem::path path; + std::string_view content; +}; + +struct compile_options +{ + std::string flags = ""; + std::string 
kernel_name = "main"; +}; + +kernel compile_kernel(const std::vector& src, + compile_options options = compile_options{}); + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp new file mode 100644 index 0000000000..6b523382dc --- /dev/null +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -0,0 +1,78 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP + +#include +#include +#include + +namespace rtc { + +template +struct buffer +{ + buffer() : ptr(), n(0) {} + buffer(std::shared_ptr p, std::size_t sz) : ptr(p), n(sz) {} + buffer(std::shared_ptr p, std::size_t sz) + : ptr(std::reinterpret_pointer_cast(p)), n(sz) + { + } + explicit buffer(std::size_t sz) : ptr(new T[sz]), n(sz) {} + T* begin() { return data(); } + T* end() { return data() + size(); } + const T* begin() const { return data(); } + const T* end() const { return data() + size(); } + + T& front() { return data()[0]; } + T& back() { return data()[size() - 1]; } + T& operator[](std::size_t i) { return data()[i]; } + T& at(std::size_t i) + { + if(i >= size()) + throw std::runtime_error("Out of bounds"); + return data()[i]; + } + + const T& front() const { return data()[0]; } + const T& back() const { return data()[size() - 1]; } + const T& operator[](std::size_t i) const { return data()[i]; } + const T& at(std::size_t i) const + { + if(i >= size()) + throw std::runtime_error("Out of bounds"); + return data()[i]; + } + const T* data() const { return ptr.get(); } + T* data() { return ptr.get(); } + + std::size_t size() const { return n; } + std::size_t bytes() const { return size() * sizeof(T); } + + bool empty() const { return size() == 0; } + + private: + std::shared_ptr ptr; + std::size_t n; +}; + +std::string get_device_name(); +std::string hip_error(int error); + +std::shared_ptr allocate_gpu(std::size_t sz, bool host = false); +std::shared_ptr write_to_gpu(const void* x, std::size_t sz, bool host = 
false); +std::shared_ptr read_from_gpu(const void* x, std::size_t sz); + +template +buffer to_gpu(const buffer& input) +{ + return {write_to_gpu(input.data(), input.bytes()), input.size()}; +} + +template +buffer from_gpu(const buffer& input) +{ + return {read_from_gpu(input.data(), input.bytes()), input.size()}; +} + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp new file mode 100644 index 0000000000..9f38e90416 --- /dev/null +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -0,0 +1,62 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL + +#include +#include +#include +#include + +namespace rtc { + +struct kernel_argument +{ + template , + class = std::enable_if_t{}>> + kernel_argument(T&& x) : size(sizeof(U)), align(alignof(U)), data(&x) // NOLINT + { + } + std::size_t size; + std::size_t align; + void* data; +}; + +std::vector pack_args(const std::vector& args); + +struct kernel_impl; + +struct kernel +{ + kernel() = default; + kernel(const char* image, const std::string& name); + template + kernel(const std::vector& image, const std::string& name) + : kernel(reinterpret_cast(image.data()), name) + { + static_assert(sizeof(T) == 1, "Only byte types"); + } + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const; + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const; + + template + auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const + { + return [=](auto&&... 
xs) { + launch(stream, global, local, std::vector{xs...}, zs...); + }; + } + + private: + std::shared_ptr impl; +}; +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/manage_ptr.hpp b/codegen/test/rtc/include/rtc/manage_ptr.hpp new file mode 100644 index 0000000000..92edf12628 --- /dev/null +++ b/codegen/test/rtc/include/rtc/manage_ptr.hpp @@ -0,0 +1,55 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER + +#include +#include + +namespace rtc { +template +struct manage_deleter +{ + template + void operator()(T* x) const + { + if(x != nullptr) + { + (void)f(x); + } + } +}; + +struct null_deleter +{ + template + void operator()(T*) const + { + } +}; + +template +using manage_ptr = std::unique_ptr>; + +template +struct element_type +{ + using type = typename T::element_type; +}; + +template +using remove_ptr = typename std:: + conditional_t{}, std::remove_pointer, element_type>::type; + +template +using shared = std::shared_ptr>; + +template +shared share(T p) +{ + return shared{std::move(p)}; +} + +#define RTC_MANAGE_PTR(T, F) rtc::manage_ptr, decltype(&F), &F> + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp new file mode 100644 index 0000000000..f0fd1f72bb --- /dev/null +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -0,0 +1,24 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR + +#include +#include + +namespace rtc { + +struct tmp_dir +{ + std::filesystem::path path; + tmp_dir(const std::string& prefix = ""); + + void execute(const std::string& cmd) const; + + tmp_dir(tmp_dir const&) = delete; + tmp_dir& operator=(tmp_dir const&) = delete; + + ~tmp_dir(); +}; + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp new file mode 100644 index 0000000000..7ea55b9328 --- /dev/null +++ 
b/codegen/test/rtc/src/compile_kernel.cpp @@ -0,0 +1,95 @@ +#include "rtc/hip.hpp" +#include +#include +#include +#include +#include +#include + +namespace rtc { + +template +T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0) +{ + std::ifstream is(filename, std::ios::binary | std::ios::ate); + if(nbytes == 0) + { + // if there is a non-zero offset and nbytes is not set, + // calculate size of remaining bytes to read + nbytes = is.tellg(); + if(offset > nbytes) + throw std::runtime_error("offset is larger than file size"); + nbytes -= offset; + } + if(nbytes < 1) + throw std::runtime_error("Invalid size for: " + filename); + is.seekg(offset, std::ios::beg); + + T buffer(nbytes, 0); + if(not is.read(&buffer[0], nbytes)) + throw std::runtime_error("Error reading file: " + filename); + return buffer; +} + +std::vector read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0) +{ + return generic_read_file>(filename, offset, nbytes); +} + +std::string read_string(const std::string& filename) +{ + return generic_read_file(filename); +} + +void write_buffer(const std::string& filename, const char* buffer, std::size_t size) +{ + std::ofstream os(filename); + os.write(buffer, size); +} +void write_buffer(const std::string& filename, const std::vector& buffer) +{ + write_buffer(filename, buffer.data(), buffer.size()); +} +void write_string(const std::string& filename, const std::string_view& buffer) +{ + write_buffer(filename, buffer.data(), buffer.size()); +} + +std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; } + +kernel compile_kernel(const std::vector& srcs, compile_options options) +{ + assert(not srcs.empty()); + tmp_dir td{"compile"}; + options.flags += " -I. 
-O3"; + options.flags += " -std=c++17"; + options.flags += " --offload-arch=" + get_device_name(); + std::string out; + + for(const auto& src : srcs) + { + std::filesystem::path full_path = td.path / src.path; + std::filesystem::path parent_path = full_path.parent_path(); + std::filesystem::create_directories(parent_path); + write_string(full_path.string(), src.content); + if(src.path.extension().string() == ".cpp") + { + options.flags += " -c " + src.path.filename().string(); + if(out.empty()) + out = src.path.stem().string() + ".o"; + } + } + + options.flags += " -o " + out; + td.execute(compiler() + options.flags); + + auto out_path = td.path / out; + if(not std::filesystem::exists(out_path)) + throw std::runtime_error("Output file missing: " + out); + + auto obj = read_buffer(out_path.string()); + + return kernel{obj.data(), options.kernel_name}; +} + +} // namespace rtc diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp new file mode 100644 index 0000000000..10e38c9adb --- /dev/null +++ b/codegen/test/rtc/src/hip.cpp @@ -0,0 +1,102 @@ +#include +#include +#include +#include + +namespace rtc { + +using hip_ptr = RTC_MANAGE_PTR(void, hipFree); + +std::string hip_error(int error) { return hipGetErrorString(static_cast(error)); } + +int get_device_id() +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + throw std::runtime_error("No device"); + return device; +} + +std::string get_device_name() +{ + hipDeviceProp_t props{}; + auto status = hipGetDeviceProperties(&props, get_device_id()); + if(status != hipSuccess) + throw std::runtime_error("Failed to get device properties"); + return props.gcnArchName; +} + +bool is_device_ptr(const void* ptr) +{ + hipPointerAttribute_t attr; + auto status = hipPointerGetAttributes(&attr, ptr); + if(status != hipSuccess) + return false; + return attr.type == hipMemoryTypeDevice; +} + +void gpu_sync() +{ + auto status = hipDeviceSynchronize(); + if(status != hipSuccess) + throw 
std::runtime_error("hip device synchronization failed: " + hip_error(status)); +} + +std::size_t get_available_gpu_memory() +{ + size_t free; + size_t total; + auto status = hipMemGetInfo(&free, &total); + if(status != hipSuccess) + throw std::runtime_error("Failed getting available memory: " + hip_error(status)); + return free; +} + +std::shared_ptr allocate_gpu(std::size_t sz, bool host) +{ + if(sz > get_available_gpu_memory()) + throw std::runtime_error("Memory not available to allocate buffer: " + std::to_string(sz)); + void* alloc_ptr = nullptr; + auto status = host ? hipHostMalloc(&alloc_ptr, sz) : hipMalloc(&alloc_ptr, sz); + if(status != hipSuccess) + { + if(host) + throw std::runtime_error("Gpu allocation failed: " + hip_error(status)); + else + return allocate_gpu(sz, true); + } + assert(alloc_ptr != nullptr); + std::shared_ptr result = share(hip_ptr{alloc_ptr}); + return result; +} + +std::shared_ptr write_to_gpu(const void* x, std::size_t sz, bool host) +{ + gpu_sync(); + auto result = allocate_gpu(sz, host); + assert(is_device_ptr(result.get())); + assert(not is_device_ptr(x)); + auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice); + if(status != hipSuccess) + throw std::runtime_error("Copy to gpu failed: " + hip_error(status)); + return result; +} + +std::shared_ptr read_from_gpu(const void* x, std::size_t sz) +{ + gpu_sync(); + std::shared_ptr result(new char[sz]); + assert(not is_device_ptr(result.get())); + if(not is_device_ptr(x)) + { + throw std::runtime_error( + "read_from_gpu() requires Src buffer to be on the GPU, Copy from gpu failed\n"); + } + auto status = hipMemcpy(result.get(), x, sz, hipMemcpyDeviceToHost); + if(status != hipSuccess) + throw std::runtime_error("Copy from gpu failed: " + hip_error(status)); // NOLINT + return std::static_pointer_cast(result); +} + +} // namespace rtc diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp new file mode 100644 index 0000000000..f4fb19130c --- 
/dev/null +++ b/codegen/test/rtc/src/kernel.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include + +// extern declare the function since hip/hip_ext.h header is broken +extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + size_t, + hipStream_t, + void**, + void**, + hipEvent_t = nullptr, + hipEvent_t = nullptr, + uint32_t = 0); + +namespace rtc { + +std::vector pack_args(const std::vector& args) +{ + std::vector kernargs; + for(auto&& arg : args) + { + std::size_t n = arg.size; + const auto* p = static_cast(arg.data); + // Insert padding + std::size_t padding = (arg.align - (kernargs.size() % arg.align)) % arg.align; + kernargs.insert(kernargs.end(), padding, 0); + kernargs.insert(kernargs.end(), p, p + n); + } + return kernargs; +} + +using hip_module_ptr = RTC_MANAGE_PTR(hipModule_t, hipModuleUnload); + +struct kernel_impl +{ + hip_module_ptr module = nullptr; + hipFunction_t fun = nullptr; +}; + +hip_module_ptr load_module(const char* image) +{ + hipModule_t raw_m; + auto status = hipModuleLoadData(&raw_m, image); + hip_module_ptr m{raw_m}; + if(status != hipSuccess) + throw std::runtime_error("Failed to load module: " + hip_error(status)); + return m; +} + +kernel::kernel(const char* image, const std::string& name) : impl(std::make_shared()) +{ + impl->module = load_module(image); + auto status = hipModuleGetFunction(&impl->fun, impl->module.get(), name.c_str()); + if(hipSuccess != status) + throw std::runtime_error("Failed to get function: " + name + ": " + hip_error(status)); +} + +void launch_kernel(hipFunction_t fun, + hipStream_t stream, + std::size_t global, + std::size_t local, + void* kernargs, + std::size_t size) +{ + assert(global > 0); + assert(local > 0); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kernargs, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &size, + HIP_LAUNCH_PARAM_END}; + + auto status = hipExtModuleLaunchKernel(fun, + global, + 1, + 1, + 
local, + 1, + 1, + 0, + stream, + nullptr, + reinterpret_cast(&config), + nullptr, + nullptr); + if(status != hipSuccess) + throw std::runtime_error("Failed to launch kernel: " + hip_error(status)); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const +{ + assert(impl != nullptr); + void* kernargs = args.data(); + std::size_t size = args.size() * sizeof(void*); + + launch_kernel(impl->fun, stream, global, local, kernargs, size); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const +{ + assert(impl != nullptr); + std::vector kernargs = pack_args(args); + std::size_t size = kernargs.size(); + + launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); +} + +} // namespace rtc \ No newline at end of file diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp new file mode 100644 index 0000000000..3b0f0170e8 --- /dev/null +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include + +namespace rtc { +std::string random_string(std::string::size_type length) +{ + static const std::string& chars = "0123456789" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::mt19937 rg{std::random_device{}()}; + std::uniform_int_distribution pick(0, chars.length() - 1); + + std::string str(length, 0); + std::generate(str.begin(), str.end(), [&] { return chars[pick(rg)]; }); + + return str; +} + +std::string unique_string(const std::string& prefix) +{ + auto pid = getpid(); + auto tid = std::this_thread::get_id(); + auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); + std::stringstream ss; + ss << std::hex << prefix << "-" << pid << "-" << tid << "-" << clk << "-" << random_string(16); + return ss.str(); +} + +tmp_dir::tmp_dir(const std::string& prefix) + : path(std::filesystem::temp_directory_path() / + unique_string(prefix.empty() 
? "ck-rtc" : "ck-rtc-" + prefix)) +{ + std::filesystem::create_directories(this->path); +} + +void tmp_dir::execute(const std::string& cmd) const +{ + std::string s = "cd " + path.string() + "; " + cmd; + std::system(s.c_str()); +} + +tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); } + +} // namespace rtc \ No newline at end of file diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 42f8daef71..77ed9625c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -498,6 +498,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return 
false; + } + } + else + { + return false; + } + return true; + } + static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) @@ -505,87 +585,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) - { - if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) - { - // FIXME: not rigorous - if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector laod of B - if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) - { - if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) - { - // FIXME: not rigorous - if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of Ds - // only support RowMajor for now - bool all_valid = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - - if constexpr(!is_same_v) - { - all_valid = false; - } - }); - - if(!all_valid) - { - return false; - } - - // check vector store of E - // only support RowMajor for now - if constexpr(is_same_v) - { - if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) - { - return false; - } - } - else - { - return false; - } - } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and + GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, @@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD + struct Descriptor + { + static constexpr auto ds_tuple() + { + return transform_tuples( + [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); }, + DsDesc{}); + } + using 
AGridDesc_M_K = + remove_cvref_t; + using BGridDesc_N_K = + remove_cvref_t; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_tuple()))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>; + using Block2ETileMap = remove_cvref_t; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k; + BGridDesc_N_K b_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + EGridDesc_M_N e_grid_desc_m_n; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CDEElementwiseOperation cde_element_op; + + // for checking vector load/store + index_t MRaw; + index_t NRaw; + index_t KRaw; + + bool has_main_k_block_loop = true; + + constexpr Descriptor(ADesc a, + BDesc b, + DsDesc ds, + EDesc e, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_) + : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)}, + b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)}, + ds_grid_desc_m_n{transform_tuples( + [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); }, + ds)}, + 
e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)}, + a_grid_desc_ak0_m_ak1{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)}, + b_grid_desc_bk0_n_bk1{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + transform_tuples( + [&](auto d) constexpr { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); + }, + ds))}, + e_grid_desc_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)}, + block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + cde_element_op{cde_element_op_}, + MRaw{e.GetLength(I0)}, + NRaw{e.GetLength(I1)}, + KRaw{a.GetLength(I1)} + { + } + + constexpr bool IsValid() const + { + return GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map) and + IsSupported(MRaw, NRaw, KRaw); + } + + constexpr index_t GetBlockSize() const { return BlockSize; } + + constexpr index_t GetGridSize() const + { + return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + DsDesc ds, + EDesc e, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{}) + { + return Descriptor( + a, b, ds, e, a_element_op, b_element_op, cde_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + 
EDataType* __restrict__ p_e_grid) + { + __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + assert(desc.IsValid()); + if(desc.has_main_k_block_loop) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); + } + else + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); + } + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 6266fb40f0..a89e14cbdb 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -24,10 +24,10 @@ struct BlockToCTileMap_M00_N0_M01 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - __host__ __device__ BlockToCTileMap_M00_N0_M01() = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01 = 1) + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) { } @@ -51,8 +51,8 @@ struct BlockToCTileMap_M00_N0_M01 } template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const + __host__ __device__ 
constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const { if constexpr(DeviceCTileIndexCheck) return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); @@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -120,18 +120,19 @@ struct BlockToCTileMap_M00_N0_M01Adapt static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = - default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = - default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + const BlockToCTileMap_M00_N0_M01Adapt&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + BlockToCTileMap_M00_N0_M01Adapt&&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) + __host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { #if 0 @@ -142,8 +143,9 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01 = 8) + 
__host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) : BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) { @@ -164,7 +166,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } @@ -237,8 +239,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const + __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const { return true; // always valid provided that user gets grid size from CalculateGridSize() } @@ -616,7 +618,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt return true; // always valid provided that user gets grid size from CalculateGridSize() } - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } private: index_t M01_; @@ -674,7 +679,7 @@ struct BlockToCTileMap_M00_N00_M01_N01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -786,7 +791,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -910,7 +915,7 @@ struct OffsettedBlockToCTileMap } template - __host__ 
bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); } @@ -967,7 +972,7 @@ struct BlockToCTileMap_3DGrid_KSplit } template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 15c30a0dad..c0a3d29f85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const BGridDesc_N_K& b_grid_desc_n_k, const DsGridDesc_M_N& ds_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n, - const Block2ETileMap& block_2_etile_map) + const Block2ETileMap&) { static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, @@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // check block-to-E-tile - if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) - { - return false; - } + // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + //{ + // return false; + //} // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // check tensor size: cannot be larger than 2GB each