From 59e2dc294d0556077df53abf0610f2d0739438ed Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Thu, 1 Jun 2023 18:54:52 -0500 Subject: [PATCH 1/3] Updates to ck host library API (#731) * Move functions to cpp file * Move another function to cpp file * Fix semicolon * Move solution to common.hpp * Fix compile errors * Use enum for data types * Remove -Werror * Fix header install * Fix relative path * Fix header path * Install all headers --- cmake/Embed.cmake | 8 +- cmake/EnableCompilerWarnings.cmake | 1 - library/src/jit_library/CMakeLists.txt | 32 ++- .../jit_library/include/ck/host/common.hpp | 34 +++ .../ck/host/device_gemm_multiple_d.hpp | 60 +++++ .../include/device_gemm_multiple_d.hpp | 217 ------------------ library/src/jit_library/src/common.cpp | 30 +++ .../src/device_gemm_multiple_d.cpp | 160 +++++++++++++ .../jit_library/util/make_instance_strings.py | 19 +- 9 files changed, 311 insertions(+), 250 deletions(-) create mode 100644 library/src/jit_library/include/ck/host/common.hpp create mode 100644 library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp delete mode 100644 library/src/jit_library/include/device_gemm_multiple_d.hpp create mode 100644 library/src/jit_library/src/common.cpp create mode 100644 library/src/jit_library/src/device_gemm_multiple_d.cpp diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake index 04c88974fd..3328f45e39 100644 --- a/cmake/Embed.cmake +++ b/cmake/Embed.cmake @@ -27,7 +27,7 @@ find_program(EMBED_OBJCOPY objcopy) function(generate_embed_source EMBED_NAME) set(options) set(oneValueArgs SRC HEADER RELATIVE) - set(multiValueArgs OBJECTS SYMBOLS) + set(multiValueArgs OBJECTS SYMBOLS FILES) cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -44,6 +44,7 @@ function(generate_embed_source EMBED_NAME) foreach(idx RANGE ${LEN}) list(GET PARSE_SYMBOLS ${idx} SYMBOL) list(GET PARSE_OBJECTS ${idx} OBJECT) + list(GET PARSE_FILES ${idx} FILE) set(START_SYMBOL "_binary_${SYMBOL}_start") set(END_SYMBOL "_binary_${SYMBOL}_end") string(APPEND EXTERNS " @@ -52,8 +53,7 @@ function(generate_embed_source EMBED_NAME) ") - file(RELATIVE_PATH BASE_NAME ${PARSE_RELATIVE} "${OBJECT}") - string(REGEX REPLACE ".[A-Za-z0-9_]$" "" BASE_NAME ${BASE_NAME}) + file(RELATIVE_PATH BASE_NAME ${PARSE_RELATIVE} "${FILE}") string(APPEND INIT_KERNELS " { \"${BASE_NAME}\", { ${START_SYMBOL}, ${END_SYMBOL}} }, @@ -121,7 +121,7 @@ function(add_embed_library EMBED_NAME) list(APPEND SYMBOLS ${OUTPUT_SYMBOL}) endforeach() message(STATUS "Generating embedding library ${EMBED_NAME}") - generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS} RELATIVE ${PARSE_RELATIVE}) + generate_embed_source(${EMBED_NAME} SRC ${SRC_FILE} HEADER ${HEADER_FILE} OBJECTS ${OUTPUT_FILES} SYMBOLS ${SYMBOLS} RELATIVE ${PARSE_RELATIVE} FILES ${PARSE_UNPARSED_ARGUMENTS}) add_library(${EMBED_NAME} STATIC ${OUTPUT_FILES} "${SRC_FILE}") target_include_directories(${EMBED_NAME} PUBLIC "$") target_compile_options(${EMBED_NAME} PRIVATE -Wno-reserved-identifier) diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 369cd0b54c..1bd697f68a 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,6 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror -Wsign-compare -Wno-extra-semi-stmt ) diff --git a/library/src/jit_library/CMakeLists.txt b/library/src/jit_library/CMakeLists.txt index 539b50aa8f..0e187b0f75 100644 --- a/library/src/jit_library/CMakeLists.txt +++ b/library/src/jit_library/CMakeLists.txt @@ -1,26 +1,30 @@ include(Embed) -file(GLOB_RECURSE KERNEL_FILES ${CONFIGURE_DEPENDS} +file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS ${PROJECT_SOURCE_DIR}/include/ck/*.hpp) message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") -add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${PROJECT_SOURCE_DIR}/build/include) +message(STATUS "RELATIVE: ${PROJECT_SOURCE_DIR}/include") +add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${PROJECT_SOURCE_DIR}/include) execute_process( - COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/util/make_instance_strings.py + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/util/make_instance_strings.py + ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu + ${CMAKE_CURRENT_BINARY_DIR}/solution_instances WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../tensor_operation_instance/gpu/ ) - -set(JIT_LIB_SOURCE - ${CMAKE_CURRENT_SOURCE_DIR}/include/device_gemm_multiple_d.hpp +add_library(jit_library STATIC + src/device_gemm_multiple_d.cpp + src/common.cpp ) - -add_library(jit_library STATIC ${JIT_LIB_SOURCE}) add_library(composable_kernel::jit_library ALIAS jit_library) set_target_properties(jit_library PROPERTIES LINKER_LANGUAGE CXX) -target_include_directories(jit_library PUBLIC +target_include_directories(jit_library PRIVATE $ + $ + $ + $ ) target_link_libraries(jit_library PRIVATE ck_headers) @@ -30,14 +34,8 @@ rocm_install( EXPORT jit_libraryTargets ) -set(INCLUDE_DIRS - ${PROJECT_SOURCE_DIR}/include/ck/ - ${PROJECT_SOURCE_DIR}/library/src/jit_library/include - ${PROJECT_SOURCE_DIR}/library/src/jit_library/solution_instances - ${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include -) - -rocm_install(DIRECTORY ${INCLUDE_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) +rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) +rocm_install(DIRECTORY ${PROJECT_SOURCE_DIR}/include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) rocm_install( EXPORT jit_libraryTargets diff --git a/library/src/jit_library/include/ck/host/common.hpp b/library/src/jit_library/include/ck/host/common.hpp new file mode 100644 index 0000000000..8b2ceacc68 --- /dev/null +++ b/library/src/jit_library/include/ck/host/common.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +namespace ck { +namespace host { + +struct Solution +{ + std::string template_str; + std::size_t block_size; + std::size_t grid_size; +}; + +enum class DataType { + Half, + Float, + Int8, + Int32 +}; + +std::string ToString(DataType dt); + +std::unordered_map> GetHeaders(); + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y); + +} // namespace host +} // namespace ck diff --git a/library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp b/library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp new file mode 100644 index 0000000000..c73715e1ae --- /dev/null +++ b/library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/common.hpp" + + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::Tuple<>"; + + static const std::size_t ds_layout_idx = 3; + static const std::size_t ds_data_type_idx = 9; + static const std::size_t e_data_type_idx = 10; + static const std::size_t a_elementwise_op_idx = 11; + static const std::size_t b_elementwise_op_idx = 12; + static const std::size_t ds_elementwise_op_idx = 13; + static const std::size_t gemm_spec_idx = 14; + static const std::size_t block_size_idx = 16; + static const std::size_t m_per_block_idx = 17; + static const std::size_t n_per_block_idx = 18; + static const std::size_t k_per_block_idx = 19; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; + +private: + std::vector GetInstances(const std::string& arch) const; + + Solution MakeSolution(std::size_t idx, const std::string& arch) const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/library/src/jit_library/include/device_gemm_multiple_d.hpp b/library/src/jit_library/include/device_gemm_multiple_d.hpp deleted file mode 100644 index 821821f1f5..0000000000 --- a/library/src/jit_library/include/device_gemm_multiple_d.hpp +++ /dev/null @@ -1,217 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "ck/solution_instances/gemm_add_add_fastgelu_instances.hpp" -#include "ck/ck.hpp" -#include "ck/utility/math.hpp" -#include "ck_headers.hpp" - - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_multiple_d { - - -struct Solution -{ - std::string template_str; - index_t block_size; - index_t grid_size; -}; - -std::string GetGemmSpec(const index_t m, - const index_t n, - const index_t k, - const index_t m_per_block, - const index_t n_per_block, - const index_t k_per_block) -{ - std::string spec = ""; - if(math::integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) - spec += "M"; - if(math::integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) - spec += "N"; - if(math::integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) - spec += "K"; - if(spec == "") - return "ck::tensor_operation::device::GemmSpecialization::Default"; - - return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; -} - -index_t GetGridSize(const index_t m, - const index_t n, - const index_t m_per_block, - const index_t n_per_block) -{ - return math::integer_divide_ceil(m, m_per_block) * - math::integer_divide_ceil(n, n_per_block); -} - -const std::unordered_set& get_xdlop_archs() -{ - static std::unordered_set supported_archs{"gfx90a"}; - return supported_archs; -} - -struct Problem -{ - index_t M = 0; - index_t N = 0; - index_t K = 0; - bool TransA = false; - bool TransB = false; - bool TransE = false; - std::vector DsLayout = {}; - std::string ADataType = "ck::half_t"; - std::string BDataType = "ck::half_t"; - std::string EDataType = "ck::half_t"; - std::vector DsDataType = {}; - std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; - std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; - std::string CDEElementOp = "ck::Tuple<>"; - - static const index_t ds_layout_idx = 3; - static const index_t ds_data_type_idx = 9; - static const index_t e_data_type_idx = 10; - static const index_t a_elementwise_op_idx = 11; - static const index_t b_elementwise_op_idx = 12; - static const index_t ds_elementwise_op_idx = 13; - static const index_t gemm_spec_idx = 14; - static const index_t block_size_idx = 16; - static const index_t m_per_block_idx = 17; - static const index_t n_per_block_idx = 18; - static const index_t k_per_block_idx = 19; - -private: - auto GetInstances(const std::string& arch) const - { - std::vector instances; - const bool quantize = ADataType == "int8_t" and BDataType == "int8_t"; - if (get_xdlop_archs().find(arch) != get_xdlop_archs().end()) - { - instance::gemm_add_add_fastgelu_instances all_instances{}; - if(TransA and TransB) - instances = all_instances.get_col_col_instances(quantize); - else if(TransA and not TransB) - instances = all_instances.get_col_row_instances(quantize); - else if(not TransA and not TransB) - instances = all_instances.get_row_row_instances(quantize); - else - instances = all_instances.get_row_col_instances(quantize); - } - return instances; - } - - auto MakeLayoutTuple(const std::vector& layouts) const - { - std::string layout_tuple = "ck::Tuple<"; - auto it = layouts.begin(); - while(it != layouts.end()) - { - layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; - it = std::next(it); - if (it != layouts.end()) - layout_tuple += ", "; - } - - return layout_tuple + ">"; - } - - auto MakeTypeTuple(const std::vector& types) const - { - std::string type_tuple = "ck::Tuple<"; - auto it = types.begin(); - while(it != types.end()) - { - type_tuple += *it; - it = std::next(it); - if (it != types.end()) - type_tuple += ", "; - } - return type_tuple + ">"; - } - - auto MakeSolution(index_t idx, const std::string& arch) const - { - auto template_str = GetInstances(arch).at(idx); - std::istringstream iss(template_str); - std::vector params(std::istream_iterator{iss}, - std::istream_iterator()); - - if (ADataType == "int8_t" and BDataType == "int8_t") - { - // Change CBlockTransfer ScalarPerVector if Ds contains other types - if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "ck::half_t"; })) - { - params[params.size() - 3] = "8"; - } - if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == "float"; })) - { - params[params.size() - 3] = "4"; - } - } - - params[a_elementwise_op_idx] = AElementOp; - params[b_elementwise_op_idx] = BElementOp; - params[ds_layout_idx] = MakeLayoutTuple(DsLayout); - params[ds_data_type_idx] = MakeTypeTuple(DsDataType); - params[ds_elementwise_op_idx] = CDEElementOp; - params[e_data_type_idx] = EDataType; - auto block_size_str = params[block_size_idx]; - auto m_per_block_str = params[m_per_block_idx]; - auto n_per_block_str = params[n_per_block_idx]; - auto k_per_block_str = params[k_per_block_idx]; - const auto block_size = std::stoi(block_size_str); - const auto m_per_block = std::stoi(m_per_block_str); - const auto n_per_block = std::stoi(n_per_block_str); - const auto k_per_block = std::stoi(k_per_block_str); - const auto grid_size = GetGridSize(M, N, m_per_block, n_per_block); - params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block); - - std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{}, - [](const std::string& a, const std::string& b) { - return a.empty() ? b : a + ", " + b; - }); - str = params.front() + "< " + str + ">"; - - return Solution{str, block_size, grid_size}; - } - -public: - auto GetHeaders() const - { - return ck_headers(); - } - - auto GetIncludeHeader() const - { - return instance::gemm_add_add_fastgelu_instances{}.get_include_header(); - } - - auto GetSolutions(const std::string& arch) const - { - std::vector solutions; - const auto num_instances = GetInstances(arch).size(); - for (auto i = 0; i < num_instances; ++i) - { - solutions.push_back(MakeSolution(i, arch)); - } - - return solutions; - } -}; - -} // namespace device_gemm_multiple_d -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/jit_library/src/common.cpp b/library/src/jit_library/src/common.cpp new file mode 100644 index 0000000000..accd182998 --- /dev/null +++ b/library/src/jit_library/src/common.cpp @@ -0,0 +1,30 @@ + +#include "ck/host/common.hpp" +#include "ck_headers.hpp" + +namespace ck { +namespace host { + +std::string ToString(DataType dt) +{ + switch (dt) { + case DataType::Float: return "float"; + case DataType::Half: return "ck::half_t"; + case DataType::Int8: return "int8_t"; + case DataType::Int32: return "int32_t"; + } + throw std::runtime_error("Incorrect data type"); +} + +std::unordered_map> GetHeaders() +{ + return ck_headers(); +} + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y) +{ + return (x + y - std::size_t{1}) / y; +} + +} // namespace host +} // namespace ck diff --git a/library/src/jit_library/src/device_gemm_multiple_d.cpp b/library/src/jit_library/src/device_gemm_multiple_d.cpp new file mode 100644 index 0000000000..6f96f92a7c --- /dev/null +++ b/library/src/jit_library/src/device_gemm_multiple_d.cpp @@ -0,0 +1,160 @@ +#include "ck/host/device_gemm_multiple_d.hpp" +#include "ck/host/common.hpp" +#include "gemm_add_add_fastgelu_instances.hpp" +#include +#include + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +std::size_t GetGridSize(const std::size_t m, + const std::size_t n, + const std::size_t m_per_block, + const std::size_t n_per_block) +{ + return integer_divide_ceil(m, m_per_block) * + integer_divide_ceil(n, n_per_block); +} + +const std::unordered_set& get_xdlop_archs() +{ + static std::unordered_set supported_archs{"gfx90a"}; + return supported_archs; +} + +std::vector Problem::GetInstances(const std::string& arch) const +{ + std::vector instances; + const bool quantize = ADataType == DataType::Int8 and BDataType == DataType::Int8; + if (get_xdlop_archs().find(arch) != get_xdlop_archs().end()) + { + instance::gemm_add_add_fastgelu_instances all_instances{}; + if(TransA and TransB) + instances = all_instances.get_col_col_instances(quantize); + else if(TransA and not TransB) + instances = all_instances.get_col_row_instances(quantize); + else if(not TransA and not TransB) + instances = all_instances.get_row_row_instances(quantize); + else + instances = all_instances.get_row_col_instances(quantize); + } + return instances; +} + +std::string MakeLayoutTuple(const std::vector& layouts) +{ + std::string layout_tuple = "ck::Tuple<"; + auto it = layouts.begin(); + while(it != layouts.end()) + { + layout_tuple += *it ? "ck::tensor_layout::gemm::ColumnMajor" : "ck::tensor_layout::gemm::RowMajor"; + it = std::next(it); + if (it != layouts.end()) + layout_tuple += ", "; + } + + return layout_tuple + ">"; +} + +std::string MakeTypeTuple(const std::vector& types) +{ + std::string type_tuple = "ck::Tuple<"; + auto it = types.begin(); + while(it != types.end()) + { + type_tuple += ToString(*it); + it = std::next(it); + if (it != types.end()) + type_tuple += ", "; + } + return type_tuple + ">"; +} + +Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const +{ + auto template_str = GetInstances(arch).at(idx); + std::istringstream iss(template_str); + std::vector params(std::istream_iterator{iss}, + std::istream_iterator()); + + if (ADataType == DataType::Int8 and BDataType == DataType::Int8) + { + // Change CBlockTransfer ScalarPerVector if Ds contains other types + if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Half; })) + { + params[params.size() - 3] = "8"; + } + if (std::any_of(DsDataType.begin(), DsDataType.end(), [](auto t) { return t == DataType::Float; })) + { + params[params.size() - 3] = "4"; + } + } + + params[a_elementwise_op_idx] = AElementOp; + params[b_elementwise_op_idx] = BElementOp; + params[ds_layout_idx] = MakeLayoutTuple(DsTrans); + params[ds_data_type_idx] = MakeTypeTuple(DsDataType); + params[ds_elementwise_op_idx] = CDEElementOp; + params[e_data_type_idx] = ToString(EDataType); + auto block_size_str = params[block_size_idx]; + auto m_per_block_str = params[m_per_block_idx]; + auto n_per_block_str = params[n_per_block_idx]; + auto k_per_block_str = params[k_per_block_idx]; + const std::size_t block_size = std::stoi(block_size_str); + const std::size_t m_per_block = std::stoi(m_per_block_str); + const std::size_t n_per_block = std::stoi(n_per_block_str); + const std::size_t k_per_block = std::stoi(k_per_block_str); + const std::size_t grid_size = GetGridSize(M, N, m_per_block, n_per_block); + params[gemm_spec_idx] = GetGemmSpec(M, N, K, m_per_block, n_per_block, k_per_block); + + std::string str = std::accumulate(params.begin() + 1, params.end(), std::string{}, + [](const std::string& a, const std::string& b) { + return a.empty() ? b : a + ", " + b; + }); + str = params.front() + "< " + str + ">"; + + return Solution{str, block_size, grid_size}; +} + +std::string Problem::GetIncludeHeader() const +{ + return instance::gemm_add_add_fastgelu_instances{}.get_include_header(); +} + +std::vector Problem::GetSolutions(const std::string& arch) const +{ + std::vector solutions; + const std::size_t num_instances = GetInstances(arch).size(); + for (std::size_t i = 0; i < num_instances; ++i) + { + solutions.push_back(MakeSolution(i, arch)); + } + + return solutions; +} + + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/library/src/jit_library/util/make_instance_strings.py b/library/src/jit_library/util/make_instance_strings.py index d06282f1f5..df6ac7d8b7 100644 --- a/library/src/jit_library/util/make_instance_strings.py +++ b/library/src/jit_library/util/make_instance_strings.py @@ -1,4 +1,4 @@ -import argparse, re, json, os +import argparse, re, json, os, sys out_file = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. @@ -10,8 +10,7 @@ out_file = """// SPDX-License-Identifier: MIT #include namespace ck {{ -namespace tensor_operation {{ -namespace device {{ +namespace host {{ namespace instance {{ struct {op_name}_instances @@ -87,8 +86,7 @@ struct {op_name}_instances }}; }} // namespace instance -}} // namespace device -}} // namespace tensor_operation +}} // namespace host }} // namespace ck """ @@ -172,8 +170,7 @@ def get_int8_instances(src, file, template_name): instances["col_row"][-1] = instances["col_row"][-1][:-1] return instances -def parse_instances(source): - out_dir = os.path.join(source, "../../../src/jit_library/solution_instances") +def parse_instances(source, out_dir): aliases = {"F16_F16_Tuple": "ck::Tuple", "Row_Row_Tuple": "ck::Tuple", "Empty_Tuple": "ck::Tuple<>", @@ -273,9 +270,9 @@ def parse_instances(source): int8_row_col_instances="\n".join(int8_instances["row_col"]), include_header=include_header)) -def run(): - source = "/code/composable_kernel/library/src/tensor_operation_instance/gpu" - parse_instances(source) +def run(args): + parse_instances(args[0], args[1]) if __name__ == '__main__': - run() \ No newline at end of file + run(sys.argv[1:]) + \ No newline at end of file From 33f88fa84ec5282dff93655e2a62ebc58849ebd9 Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 2 Jun 2023 15:15:53 -0500 Subject: [PATCH 2/3] Add missing header --- library/src/jit_library/src/common.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/library/src/jit_library/src/common.cpp b/library/src/jit_library/src/common.cpp index accd182998..067b60df68 100644 --- a/library/src/jit_library/src/common.cpp +++ b/library/src/jit_library/src/common.cpp @@ -1,6 +1,7 @@ #include "ck/host/common.hpp" #include "ck_headers.hpp" +#include namespace ck { namespace host { From 2470dcd5e4e365368c02dde3ba7c4f9bb2aa3cbe Mon Sep 17 00:00:00 2001 From: Paul Date: Fri, 2 Jun 2023 15:17:52 -0500 Subject: [PATCH 3/3] Return not found for missing component --- Config.cmake.in | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Config.cmake.in b/Config.cmake.in index 03f299a96c..0a19915565 100644 --- a/Config.cmake.in +++ b/Config.cmake.in @@ -7,5 +7,10 @@ foreach(_comp ${composable_kernel_FIND_COMPONENTS}) set(composable_kernel_FOUND False) set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}") endif() - include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake") + if(EXISTS "${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake") + include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake") + else() + set(composable_kernel_FOUND False) + set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}") + endif() endforeach()