mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Codegen build w/CK (#1428)
* initial push * cleaned up compiler errors * removed commented code * build codegen folder only for gfx9 targets * remove separate stage for codegen tests from CI * removed commented code from CMake --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin <Illia.Silin@amd.com>
This commit is contained in:
@@ -1,6 +1,3 @@
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(composable_kernel_host LANGUAGES CXX HIP)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
||||
@@ -8,17 +5,9 @@ set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
|
||||
find_package(ROCM)
|
||||
include(ROCMInstallTargets)
|
||||
include(ROCMTest)
|
||||
|
||||
add_compile_options(-std=c++17)
|
||||
find_package(hip)
|
||||
## HIP
|
||||
set(CMAKE_HIP_PLATFORM amd)
|
||||
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
|
||||
set(CMAKE_HIP_EXTENSIONS ON)
|
||||
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
|
||||
add_custom_target(codegen)
|
||||
|
||||
# add include directories
|
||||
include_directories(BEFORE
|
||||
@@ -32,8 +21,9 @@ list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
|
||||
include(Embed)
|
||||
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
${CK_ROOT}/include/ck/*.hpp)
|
||||
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
|
||||
message(STATUS "RELATIVE: ${CK_ROOT}/include")
|
||||
#printouts fot debug purposes
|
||||
#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
|
||||
#message(STATUS "RELATIVE: ${CK_ROOT}/include")
|
||||
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
|
||||
|
||||
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
|
||||
@@ -76,8 +76,11 @@ std::string SequenceStr(const std::vector<int>& v);
|
||||
|
||||
std::string MakeTuple(const std::vector<std::string>& v);
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wglobal-constructors"
|
||||
template <int... xs>
|
||||
const std::string S = SequenceStr({xs...});
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
|
||||
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/types.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <cassert>
|
||||
|
||||
@@ -32,11 +33,11 @@ static std::string GetGemmSpec(const std::size_t m,
|
||||
}
|
||||
|
||||
// function to update prologue/epilogue with user provided operation
|
||||
void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
|
||||
void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
|
||||
{
|
||||
if(!prologue.empty())
|
||||
if(!pro.empty())
|
||||
{
|
||||
this->prologue = prologue;
|
||||
this->prologue = pro;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
@@ -45,11 +46,11 @@ void Operation_Xdl_CShuffle::update_prologue(const std::string& prologue)
|
||||
}
|
||||
}
|
||||
|
||||
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epilogue)
|
||||
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
|
||||
{
|
||||
if(!epilogue.empty())
|
||||
if(!epi.empty())
|
||||
{
|
||||
this->epilogue = epilogue;
|
||||
this->epilogue = epi;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp"
|
||||
#include <iostream>
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/types.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <cassert>
|
||||
|
||||
@@ -11,34 +12,15 @@ namespace ck {
|
||||
namespace host {
|
||||
namespace conv {
|
||||
|
||||
// calculate appropriate Gemm Specification based on input tensor dimensions
|
||||
// NOTE: in CK, MNKPadding is always used for forward convolution
|
||||
static std::string GetGemmSpec(const std::size_t m,
|
||||
const std::size_t n,
|
||||
const std::size_t k,
|
||||
const std::size_t m_per_block,
|
||||
const std::size_t n_per_block,
|
||||
const std::size_t k_per_block)
|
||||
{
|
||||
std::string spec = "";
|
||||
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
|
||||
spec += "M";
|
||||
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
|
||||
spec += "N";
|
||||
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
|
||||
spec += "K";
|
||||
if(spec == "")
|
||||
return "ck::tensor_operation::device::GemmSpecialization::Default";
|
||||
|
||||
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
|
||||
}
|
||||
// NOTE: in CK, MNKPadding is always used for forward convolution, so didn't
|
||||
// add GemmSpec function here
|
||||
|
||||
// function to update prologue/epilogue with user provided operation
|
||||
void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologue)
|
||||
void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& pro)
|
||||
{
|
||||
if(!prologue.empty())
|
||||
if(!pro.empty())
|
||||
{
|
||||
this->prologue = prologue;
|
||||
this->prologue = pro;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
@@ -47,11 +29,11 @@ void Operation_Conv_Fwd_Xdl_Cshuffle::update_prologue(const std::string& prologu
|
||||
}
|
||||
}
|
||||
|
||||
void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epilogue)
|
||||
void Operation_Conv_Fwd_Xdl_Cshuffle::update_epilogue(const std::string& epi)
|
||||
{
|
||||
if(!epilogue.empty())
|
||||
if(!epi.empty())
|
||||
{
|
||||
this->epilogue = epilogue;
|
||||
this->epilogue = epi;
|
||||
this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
namespace ck {
|
||||
namespace host {
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wglobal-constructors"
|
||||
const std::string config_header = "";
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
std::unordered_map<std::string_view, std::string_view> GetHeaders()
|
||||
{
|
||||
|
||||
@@ -4,7 +4,9 @@ file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
|
||||
foreach(TEST_SRC ${TEST_SRCS})
|
||||
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
|
||||
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
|
||||
rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC})
|
||||
add_executable(test_host_${BASE_NAME} ${TEST_SRC})
|
||||
add_dependencies(codegen test_host_${BASE_NAME})
|
||||
add_test(NAME codegen_test_${BASE_NAME} COMMAND test_host_${BASE_NAME})
|
||||
target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host)
|
||||
# target_link_libraries(test_host_${BASE_NAME} ${CK_ROOT}/build/lib/libutility.a)
|
||||
target_include_directories(test_host_${BASE_NAME} PUBLIC include())
|
||||
|
||||
@@ -92,7 +92,6 @@ struct Epilogue
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
@@ -109,7 +108,6 @@ struct Epilogue
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
|
||||
@@ -92,7 +92,6 @@ struct Epilogue
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
@@ -109,7 +108,6 @@ struct Epilogue
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
|
||||
@@ -92,7 +92,6 @@ struct Epilogue
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
@@ -109,7 +108,6 @@ struct Epilogue
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {2, 2};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
|
||||
@@ -92,7 +92,6 @@ struct Epilogue
|
||||
static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Y),
|
||||
static_cast<int>(prob.X)};
|
||||
ck::Array<ck::index_t, 5> d_lengths = {};
|
||||
|
||||
ck::Array<ck::index_t, 5> in_strides{static_cast<int>(prob.C),
|
||||
static_cast<int>(prob.Hi * prob.Wi * prob.G * prob.C),
|
||||
@@ -109,7 +108,6 @@ struct Epilogue
|
||||
1,
|
||||
static_cast<int>(prob.X * prob.C),
|
||||
static_cast<int>(prob.C)};
|
||||
ck::Array<ck::index_t, 5> d_strides = {};
|
||||
|
||||
ck::Array<ck::index_t, 2> conv_filter_strides = {1, 1};
|
||||
ck::Array<ck::index_t, 2> conv_filter_dilations = {1, 1};
|
||||
|
||||
@@ -118,4 +118,4 @@ void kernel::launch(hipStream_t stream,
|
||||
launch_kernel(impl->fun, stream, global, local, kernargs.data(), size);
|
||||
}
|
||||
|
||||
} // namespace rtc
|
||||
} // namespace rtc
|
||||
|
||||
@@ -45,4 +45,4 @@ void tmp_dir::execute(const std::string& cmd) const
|
||||
|
||||
tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); }
|
||||
|
||||
} // namespace rtc
|
||||
} // namespace rtc
|
||||
|
||||
Reference in New Issue
Block a user