mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Introduce gemm_softmax_gemm to codegen.
This commit is contained in:
@@ -31,12 +31,21 @@ file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
##message(STATUS "SOURCE_FILES: ${SOURCES}")
|
||||
# TODO: Use object library
|
||||
add_library(ck_host STATIC ${SOURCES})
|
||||
target_link_libraries(ck_host PRIVATE ck_headers)
|
||||
add_library(composable_kernel::ck_host ALIAS ck_host)
|
||||
|
||||
set_target_properties(ck_host PROPERTIES
|
||||
LINKER_LANGUAGE CXX
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
target_include_directories(ck_host SYSTEM PRIVATE
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
# $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/library/src/jit_library/solution_instances>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/solution_instances>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include>
|
||||
)
|
||||
|
||||
target_link_libraries(ck_host PRIVATE $<BUILD_INTERFACE:ck_headers>)
|
||||
|
||||
target_include_directories(ck_host PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
)
|
||||
@@ -45,9 +54,18 @@ add_executable(ck-template-driver driver/main.cpp)
|
||||
target_link_libraries(ck-template-driver ck_host)
|
||||
|
||||
rocm_install(
|
||||
TARGETS ck_host ck_headers
|
||||
TARGETS ck_host
|
||||
EXPORT ck_hostTargets
|
||||
)
|
||||
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
add_subdirectory(test)
|
||||
rocm_install(
|
||||
EXPORT ck_hostTargets
|
||||
FILE composable_kernelck_hostTargets.cmake
|
||||
NAMESPACE composable_kernel::
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
)
|
||||
|
||||
if (NOT CK_BUILD_HOST_LIB)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ck/host/types.hpp"
|
||||
#include "ck/host/operation/gemm.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// defines all values needed for an instance of batched gemm + softmax + gemm
|
||||
// Host-side description of one DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle instance:
// tensor descriptions, element-wise operator names and tuning parameters.
// ToSolution() renders these values as text into a template instantiation.
struct Operation_Xdl_CShuffle
|
||||
{
|
||||
// returns instances built from a default problem spec, given only the fusion operators;
// the result is a vector of instance vectors
|
||||
static std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
CreateOperations(const std::string& prologue, const std::string& epilogue);
|
||||
// returns a vector of instances, given a problem spec and fusion operators
|
||||
static std::vector<Operation_Xdl_CShuffle>
|
||||
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
|
||||
// element type and layout of the A, B, B1 and C tensors
TensorDesc A{};
|
||||
TensorDesc B{};
|
||||
TensorDesc B1{};
|
||||
TensorDesc C{};
|
||||
// names of the element-wise operators, inserted as text into the generated instance
std::string a_elem_op = PassThrough;
|
||||
std::string b_elem_op = PassThrough;
|
||||
std::string b1_elem_op = PassThrough;
|
||||
std::string c_elem_op = PassThrough;
|
||||
std::string acc_elem_op = Scale;
|
||||
// user-provided fusion code; empty when none was supplied
std::string prologue = "";
|
||||
std::string epilogue = "";
|
||||
// fully-qualified GemmSpecialization enumerator name (padding variant)
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
|
||||
// tuning parameters
|
||||
operation::TileDescGemmSoftmaxGemm tile_desc{};
|
||||
operation::BlockTransferDesc a_block_transfer{};
|
||||
operation::BlockTransferDesc b0_block_transfer{};
|
||||
operation::BlockTransferDesc b1_block_transfer{};
|
||||
operation::CShuffleDesc cshuffle{};
|
||||
operation::CBlockTransferDesc c_block_transfer{};
|
||||
|
||||
// value of the MaskOutUpperTriangle template argument
bool mask_out_upper_triangle = false;
|
||||
|
||||
// functions to update fusion operators if provided
|
||||
void update_prologue(const std::string& prologue);
|
||||
void update_epilogue(const std::string& epilogue);
|
||||
// NOTE(review): no definition of IsSupported is visible in this change - confirm it is
// implemented before calling it
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
|
||||
// returns a templated instance
|
||||
Solution ToSolution() const;
|
||||
};
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ck/host/types.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// defines the problem specification for a batched GEMM + Softmax + GEMM operation
|
||||
// Problem specification: sizes, layouts, data types and element-wise operator names.
// GetSolutions() turns it into the list of kernel instances for a given architecture.
struct Problem
|
||||
{
|
||||
// Problem sizes. When instances are created, M, N, K and O are compared against the
// Gemm01 M, Gemm0 N, Gemm0 K and Gemm1 N per-block tile sizes to choose the padding
// specialization.
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
// Layout flags: true maps to Layout::Column, false to Layout::Row.
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
// Element types of the A, B, B1 and C tensors.
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
// Names of the element-wise operators, copied verbatim into each created instance.
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
|
||||
// returns the correct device op file for the operation
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
// returns a list of instances based on the problem spec and provided fusion operations;
// the list is empty when arch is not a supported architecture
|
||||
std::vector<Solution> GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const;
|
||||
};
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
|
||||
operation::BlockTransferDesc b_block_transfer{};
|
||||
operation::CShuffleDesc cshuffle{};
|
||||
operation::CBlockTransferDesc c_block_transfer{};
|
||||
LoopScheduler loop_scheduler{};
|
||||
PipelineVersion pipeline_version{};
|
||||
|
||||
// functions to update fusion operators if provided
|
||||
void update_prologue(const std::string& prologue);
|
||||
|
||||
@@ -23,6 +23,26 @@ struct TileDesc
|
||||
int n_Xdl_per_wave = 0;
|
||||
int num_gemmk_prefetch_stage = 0;
|
||||
};
|
||||
|
||||
// Tile-level tuning parameters of one gemm + softmax + gemm instance.
// Each field is rendered as text into the template argument of the same name
// (BlockSize, Gemm01MPerBlock, ..., NumGemmkPrefetchStage) when a solution is built.
struct TileDescGemmSoftmaxGemm
|
||||
{
|
||||
int block_size = 0;
|
||||
// M per block, shared by both GEMMs
int gemm01_m_per_block = 0;
|
||||
// N and K per block of the first GEMM
int gemm0_n_per_block = 0;
|
||||
int gemm0_k_per_block = 0;
|
||||
// N and K per block of the second GEMM
int gemm1_n_per_block = 0;
|
||||
int gemm1_k_per_block = 0;
|
||||
int ak1 = 0;
|
||||
int bk1 = 0;
|
||||
int b1k1 = 0;
|
||||
int m_per_XDL = 0;
|
||||
int n_per_XDL = 0;
|
||||
int gemm0_m_Xdl_per_wave = 0;
|
||||
int gemm0_n_Xdl_per_wave = 0;
|
||||
int gemm1_n_Xdl_per_wave = 0;
|
||||
int num_gemmk_prefetch_stage = 0;
|
||||
};
|
||||
|
||||
struct BlockTransferDesc
|
||||
{
|
||||
std::string thread_cluster_length = "";
|
||||
|
||||
@@ -66,6 +66,20 @@ enum class GemmType
|
||||
};
|
||||
std::string ToString(GemmType gt);
|
||||
|
||||
// Loop scheduler choice stored on an operation description.
// A matching ToString(LoopScheduler) overload is declared right after this enum.
enum class LoopScheduler
|
||||
{
|
||||
Default,
|
||||
Interwave,
|
||||
};
|
||||
std::string ToString(LoopScheduler ls);
|
||||
|
||||
// Pipeline version choice stored on an operation description.
// A matching ToString(PipelineVersion) overload is declared right after this enum.
enum class PipelineVersion
|
||||
{
|
||||
v1,
|
||||
v2
|
||||
};
|
||||
std::string ToString(PipelineVersion pv);
|
||||
|
||||
struct TensorDesc
|
||||
{
|
||||
DataType element;
|
||||
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
|
||||
|
||||
// Fully-qualified names of element-wise operators. They are used as text, e.g. as the
// default values of the element-op fields of problem and operation descriptions.
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
|
||||
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
|
||||
constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
|
||||
38
codegen/src/device_batched_gemm_softmax_gemm.cpp
Normal file
38
codegen/src/device_batched_gemm_softmax_gemm.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <algorithm>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// return the relevant device op file based on the operation
|
||||
std::string Problem::GetIncludeHeader() const
|
||||
{
|
||||
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
|
||||
}
|
||||
|
||||
// returns templated instances when provided with a problem specification
|
||||
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const
|
||||
{
|
||||
if(get_xdlop_archs().count(arch) == 0)
|
||||
return {};
|
||||
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
|
||||
*this, prologue, epilogue); // obtains vector of instances
|
||||
std::vector<Solution> result;
|
||||
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
|
||||
return op.ToSolution(); // template instance with correct values
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,412 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <cassert>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// Chooses the GemmSpecialization enumerator name for a problem/tile combination.
//
// A dimension needs padding when it is not an exact multiple of its per-block tile size.
// The letters M, N, K and O (for n1) of the dimensions that need padding are concatenated
// in that order, giving e.g. "...::MNPadding"; if none needs padding the result is
// "...::Default".
//
// All *_per_block arguments must be non-zero.
std::string GetGemmSpec(const std::size_t m,
                        const std::size_t n,
                        const std::size_t k,
                        const std::size_t n1,
                        const std::size_t m_per_block,
                        const std::size_t n_per_block,
                        const std::size_t k_per_block,
                        const std::size_t n1_per_block)
{
    // A direct remainder test is equivalent to "round up to a tile multiple and compare"
    // but cannot wrap around for sizes close to the maximum of std::size_t.
    std::string spec;
    if(m % m_per_block != 0)
        spec += "M";
    if(n % n_per_block != 0)
        spec += "N";
    if(k % k_per_block != 0)
        spec += "K";
    if(n1 % n1_per_block != 0)
        spec += "O";

    if(spec.empty())
        return "ck::tensor_operation::device::GemmSpecialization::Default";

    return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
|
||||
|
||||
// function to update prologue/epilogue with user provided operation
|
||||
void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
|
||||
{
|
||||
if(!prologue.empty())
|
||||
{
|
||||
this->prologue = pro;
|
||||
// TODO
|
||||
// this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
{
|
||||
this->prologue = "";
|
||||
}
|
||||
}
|
||||
|
||||
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
|
||||
{
|
||||
if(!epilogue.empty())
|
||||
{
|
||||
this->epilogue = epi;
|
||||
// TODO
|
||||
// this->cde_elem_op = "CDEElementOp";
|
||||
}
|
||||
else
|
||||
{
|
||||
this->epilogue = "";
|
||||
}
|
||||
}
|
||||
|
||||
// accounts for all possible combinations of Row/Col major
|
||||
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
|
||||
|
||||
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
|
||||
// instances
|
||||
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
const Problem& prob, const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
std::vector<Operation_Xdl_CShuffle> result;
|
||||
|
||||
std::vector<operation::TileDescGemmSoftmaxGemm> tile_descriptions = {
|
||||
// clang-format off
|
||||
// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK|
|
||||
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
|
||||
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
|
||||
// | | | | | | | | | | | Wave| Wave| Wave| |
|
||||
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
// Padded fallback kernel
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
|
||||
// Irregular k
|
||||
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
const std::vector<operation::BlockTransferDesc> a_block_descriptions = {
|
||||
// clang-format off
|
||||
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
|
||||
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
// Padded fallback kernel
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
// Irregular k
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
const std::vector<operation::BlockTransferDesc> b1_block_descriptions = {
|
||||
// clang-format off
|
||||
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
|
||||
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
// Padded fallback kernel
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
// Irregular k
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
|
||||
// clang-format off
|
||||
// CShuffle| CShuffle|
|
||||
// MXdlPerWave| NXdlPerWave|
|
||||
// PerShuffle| PerShuffle|
|
||||
// | |
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 8},
|
||||
{ 1, 4},
|
||||
{ 1, 8},
|
||||
{ 1, 4},
|
||||
// Padded fallback kernel
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
// Irregular k
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::CBlockTransferDesc> c_block_descriptions = {
|
||||
// clang-format off
|
||||
// CBlockTransferClusterLengths| CBlockTransfer
|
||||
// _MBlock_MWaveMPerXdl| ScalarPerVector
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// Padded fallback kernel
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// Irregular k
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
assert(tile_descriptions.size() == a_block_descriptions.size());
|
||||
assert(tile_descriptions.size() == b1_block_descriptions.size());
|
||||
assert(tile_descriptions.size() == cshuffle_descriptions.size());
|
||||
assert(tile_descriptions.size() == c_block_descriptions.size());
|
||||
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
{
|
||||
Operation_Xdl_CShuffle x;
|
||||
x.tile_desc = tile_descriptions[i];
|
||||
x.a_block_transfer = a_block_descriptions[i];
|
||||
x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a
|
||||
x.b1_block_transfer = b1_block_descriptions[i];
|
||||
x.cshuffle = cshuffle_descriptions[i];
|
||||
x.c_block_transfer = c_block_descriptions[i];
|
||||
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
|
||||
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
|
||||
x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)};
|
||||
x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)};
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.b1_elem_op = prob.B1ElementOp;
|
||||
x.c_elem_op = prob.CElementOp;
|
||||
x.acc_elem_op = prob.AccElementOp;
|
||||
x.gemm_specialization = GetGemmSpec(prob.M,
|
||||
prob.N,
|
||||
prob.K,
|
||||
prob.O,
|
||||
x.tile_desc.gemm01_m_per_block,
|
||||
x.tile_desc.gemm0_n_per_block,
|
||||
x.tile_desc.gemm0_k_per_block,
|
||||
x.tile_desc.gemm1_n_per_block);
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
x.mask_out_upper_triangle = true;
|
||||
result.push_back(x);
|
||||
|
||||
x.mask_out_upper_triangle = false;
|
||||
result.push_back(x);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// set up instances when not provided with a problem specification, use default operation values and
|
||||
// all possible layout combinations
|
||||
std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
Problem prob;
|
||||
prob.TransA = false;
|
||||
prob.TransB = true;
|
||||
prob.TransB1 = false;
|
||||
prob.TransC = false;
|
||||
|
||||
return {CreateOperations(prob, prologue, epilogue)};
|
||||
}
|
||||
|
||||
// Text template of one DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle instantiation.
// Every ${Name} placeholder is filled in by InterpolateString from the map built in
// ToSolution(); the placeholders appear in the order of the template arguments.
static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate =
|
||||
"ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, "
|
||||
"${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, "
|
||||
"${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
|
||||
"${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, "
|
||||
"${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
|
||||
"${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
|
||||
"${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, "
|
||||
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
|
||||
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
|
||||
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
|
||||
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
|
||||
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
|
||||
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
|
||||
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
|
||||
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
|
||||
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
|
||||
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
|
||||
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
|
||||
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
|
||||
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
|
||||
"${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, "
|
||||
"${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>";
|
||||
|
||||
// use hardcoded instances from vector of operations to substitute values into instance template
|
||||
Solution Operation_Xdl_CShuffle::ToSolution() const
|
||||
{
|
||||
std::unordered_map<std::string, std::string> values = {
|
||||
{"name",
|
||||
std::to_string(this->tile_desc.block_size) + "_" +
|
||||
std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
|
||||
std::to_string(this->tile_desc.b1k1) + "_" +
|
||||
std::to_string(this->tile_desc.m_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.n_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" +
|
||||
std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
|
||||
{"LayoutA", ToString(this->A.layout)},
|
||||
{"LayoutB0", ToString(this->B.layout)},
|
||||
{"LayoutB1", ToString(this->B1.layout)},
|
||||
{"LayoutC", ToString(this->C.layout)},
|
||||
{"ADataType", ToString(this->A.element)},
|
||||
{"B0DataType", ToString(this->B.element)},
|
||||
{"B1DataType", ToString(this->B1.element)},
|
||||
{"CDataType", ToString(this->C.element)},
|
||||
{"AccDataType", ToString(DataType::Float)},
|
||||
{"CShuffleDataType", ToString(DataType::Half)},
|
||||
{"AElementwiseOperation", this->a_elem_op},
|
||||
{"B0ElementwiseOperation", this->b_elem_op},
|
||||
{"Acc0ElementwiseOperation", this->acc_elem_op},
|
||||
{"B1ElementwiseOperation", this->b1_elem_op},
|
||||
{"CElementwiseOperation", this->c_elem_op},
|
||||
{"GemmSpecialization", this->gemm_specialization},
|
||||
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
|
||||
{"BlockSize", std::to_string(this->tile_desc.block_size)},
|
||||
{"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
|
||||
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
|
||||
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
|
||||
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
|
||||
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
|
||||
{"AK1", std::to_string(this->tile_desc.ak1)},
|
||||
{"BK1", std::to_string(this->tile_desc.bk1)},
|
||||
{"B1K1", std::to_string(this->tile_desc.b1k1)},
|
||||
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
|
||||
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
|
||||
{"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)},
|
||||
{"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)},
|
||||
{"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
|
||||
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
|
||||
this->a_block_transfer.thread_cluster_length},
|
||||
{"ABlockTransferThreadClusterArrangeOrder",
|
||||
this->a_block_transfer.thread_cluster_arrange_order},
|
||||
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
|
||||
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
|
||||
{"ABlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->a_block_transfer.src_scalar_per_vector)},
|
||||
{"ABlockTransferDstScalarPerVector_AK1",
|
||||
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
|
||||
{"B0BlockTransferThreadClusterLengths_BK0_N_BK1",
|
||||
this->b0_block_transfer.thread_cluster_length},
|
||||
{"B0BlockTransferThreadClusterArrangeOrder",
|
||||
this->b0_block_transfer.thread_cluster_arrange_order},
|
||||
{"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order},
|
||||
{"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)},
|
||||
{"B0BlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->b0_block_transfer.src_scalar_per_vector)},
|
||||
{"B0BlockTransferDstScalarPerVector_BK1",
|
||||
std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)},
|
||||
{"B1BlockTransferThreadClusterLengths_BK0_N_BK1",
|
||||
this->b1_block_transfer.thread_cluster_length},
|
||||
{"B1BlockTransferThreadClusterArrangeOrder",
|
||||
this->b1_block_transfer.thread_cluster_arrange_order},
|
||||
{"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order},
|
||||
{"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)},
|
||||
{"B1BlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->b1_block_transfer.src_scalar_per_vector)},
|
||||
{"B1BlockTransferDstScalarPerVector_BK1",
|
||||
std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)},
|
||||
{"CShuffleMXdlPerWavePerShuffle",
|
||||
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
|
||||
{"CShuffleNXdlPerWavePerShuffle",
|
||||
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
|
||||
{"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl",
|
||||
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
|
||||
{"CBlockTransferScalarPerVector_NWaveNPerXdl",
|
||||
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
|
||||
{"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)},
|
||||
};
|
||||
|
||||
return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values),
|
||||
std::move(values)};
|
||||
}
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -62,6 +62,13 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
|
||||
// accounts for all possible combinations of Row/Col major
|
||||
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
|
||||
|
||||
|
||||
|
||||
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
|
||||
|
||||
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
|
||||
|
||||
|
||||
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
|
||||
// instances
|
||||
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
@@ -83,6 +90,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
|
||||
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
|
||||
// Irregular tile
|
||||
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -100,6 +109,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -109,15 +120,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
|
||||
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
// clang-format on
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
};
|
||||
|
||||
std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
|
||||
@@ -134,6 +147,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -151,6 +166,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -167,6 +184,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
// clang-format on
|
||||
};
|
||||
@@ -185,6 +203,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<1, 16, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// Irregular tile
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -199,33 +219,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
assert(tile_descriptions.size() == cshuffle_descriptions.size());
|
||||
assert(tile_descriptions.size() == c_block_descriptions.size());
|
||||
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
|
||||
{
|
||||
{LoopScheduler::Default, PipelineVersion::v1},
|
||||
{LoopScheduler::Interwave, PipelineVersion::v1},
|
||||
{LoopScheduler::Default, PipelineVersion::v2},
|
||||
};
|
||||
for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
|
||||
{
|
||||
Operation_Xdl_CShuffle x;
|
||||
x.tile_desc = tile_descriptions[i];
|
||||
x.a_block_transfer = a_block_descriptions[i];
|
||||
x.b_block_transfer = b_block_descriptions[i];
|
||||
x.cshuffle = cshuffle_descriptions[i];
|
||||
x.c_block_transfer = c_block_descriptions[i];
|
||||
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
|
||||
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
|
||||
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
|
||||
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
|
||||
return TensorDesc{dt, ToLayout(trans)};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.gemm_specialization = GetGemmSpec(prob.M,
|
||||
prob.N,
|
||||
prob.K,
|
||||
x.tile_desc.m_per_block,
|
||||
x.tile_desc.n_per_block,
|
||||
x.tile_desc.k_per_block);
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
{
|
||||
Operation_Xdl_CShuffle x;
|
||||
x.tile_desc = tile_descriptions[i];
|
||||
x.a_block_transfer = a_block_descriptions[i];
|
||||
x.b_block_transfer = b_block_descriptions[i];
|
||||
x.cshuffle = cshuffle_descriptions[i];
|
||||
x.c_block_transfer = c_block_descriptions[i];
|
||||
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
|
||||
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
|
||||
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
|
||||
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
|
||||
return TensorDesc{dt, ToLayout(trans)};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.gemm_specialization = GetGemmSpec(prob.M,
|
||||
prob.N,
|
||||
prob.K,
|
||||
x.tile_desc.m_per_block,
|
||||
x.tile_desc.n_per_block,
|
||||
x.tile_desc.k_per_block);
|
||||
x.loop_scheduler = loop_scheduler;
|
||||
x.pipeline_version = pipeline_version;
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -263,7 +294,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
|
||||
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
|
||||
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
|
||||
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
|
||||
"${CDEBlockTransferScalarPerVector_NPerBlock}>";
|
||||
"${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
|
||||
|
||||
// use hardcoded instances from vector of operations to substitute values into instance template
|
||||
Solution Operation_Xdl_CShuffle::ToSolution() const
|
||||
@@ -336,6 +367,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
|
||||
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
|
||||
{"CDEBlockTransferScalarPerVector_NPerBlock",
|
||||
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
|
||||
{"LoopScheduler", ToString(this->loop_scheduler)},
|
||||
{"PipelineVersion", ToString(this->pipeline_version)},
|
||||
};
|
||||
|
||||
return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
|
||||
|
||||
@@ -56,6 +56,26 @@ std::string ToString(GemmType gt)
|
||||
throw std::runtime_error("Incorrect gemm type");
|
||||
}
|
||||
|
||||
std::string ToString(LoopScheduler ls)
|
||||
{
|
||||
switch(ls)
|
||||
{
|
||||
case LoopScheduler::Default: return "ck::LoopScheduler::Default";
|
||||
case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
|
||||
}
|
||||
throw std::runtime_error("Incorrect LoopScheduler type");
|
||||
}
|
||||
|
||||
std::string ToString(PipelineVersion pv)
|
||||
{
|
||||
switch(pv)
|
||||
{
|
||||
case PipelineVersion::v1: return "ck::PipelineVersion::v1";
|
||||
case PipelineVersion::v2: return "ck::PipelineVersion::v2";
|
||||
}
|
||||
throw std::runtime_error("Incorrect PipelineVersion type");
|
||||
}
|
||||
|
||||
std::string SequenceStr(const std::vector<int>& v)
|
||||
{
|
||||
return "ck::Sequence<" +
|
||||
|
||||
@@ -15,7 +15,8 @@ std::vector<rtc::src_file> get_headers_for_test()
|
||||
auto hs = ck::host::GetHeaders();
|
||||
std::transform(
|
||||
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
|
||||
return {p.first, p.second};
|
||||
std::string sec(p.second.begin(), p.second.end());
|
||||
return {p.first, sec};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "ck/host/device_gemm_multiple_d/problem.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
@@ -15,13 +17,59 @@
|
||||
using half = _Float16;
|
||||
// using half = __fp16;
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
const char* const disable_warning_pragma = R"__migraphx__(
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
${content}
|
||||
#pragma clang diagnostic pop
|
||||
)__migraphx__";
|
||||
|
||||
template <class P>
|
||||
std::string ck_disable_warnings(P p)
|
||||
{
|
||||
return ck::host::InterpolateString(disable_warning_pragma,
|
||||
{{"content", std::string{p.data(), p.size()}}});
|
||||
}
|
||||
|
||||
static std::unordered_map<std::string, std::string> create_ck_header_strings()
|
||||
{
|
||||
std::unordered_map<std::string, std::string> result;
|
||||
auto ck_headers = ck::host::GetHeaders();
|
||||
|
||||
std::transform(
|
||||
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
|
||||
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::vector<rtc::src_file> create_ck_headers()
|
||||
{
|
||||
static const auto& header_strings = create_ck_header_strings();
|
||||
std::vector<rtc::src_file> srcs;
|
||||
std::transform(
|
||||
header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto& p) -> rtc::src_file {
|
||||
std::string sec(p.second.begin(), p.second.end());
|
||||
return {p.first, sec};
|
||||
});
|
||||
return srcs;
|
||||
}
|
||||
|
||||
static inline const std::vector<rtc::src_file>& ck_headers()
|
||||
{
|
||||
static const auto& headers = create_ck_headers();
|
||||
return headers;
|
||||
}
|
||||
|
||||
std::vector<rtc::src_file> get_headers_for_test()
|
||||
{
|
||||
std::vector<rtc::src_file> result;
|
||||
auto hs = ck::host::GetHeaders();
|
||||
std::transform(
|
||||
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
|
||||
return {p.first, p.second};
|
||||
std::string sec(p.second.begin(), p.second.end());
|
||||
return {p.first, sec};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
@@ -130,10 +178,13 @@ const std::string gemm_compile_check = R"__ck__(
|
||||
|
||||
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
|
||||
using G = ${template};
|
||||
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
|
||||
ck::make_tuple(),
|
||||
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
|
||||
constexpr auto desc =
|
||||
G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
|
||||
${k})),
|
||||
ck::make_naive_tensor_descriptor(ck::make_tuple(${n},
|
||||
${k}), ck::make_tuple(1, ${n})), ck::make_tuple(),
|
||||
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
|
||||
${n})));
|
||||
|
||||
static_assert(desc.IsValid(), "Invalid ck gemm.");
|
||||
|
||||
@@ -163,23 +214,32 @@ TEST_CASE(test_problem_kernel)
|
||||
std::string epilogue = "";
|
||||
std::string prologue = "";
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
|
||||
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
|
||||
std::cout << "Num solutions: " << solutions.size() << std::endl;
|
||||
for(auto i = 0; i < solutions.size(); ++i)
|
||||
{
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
// auto srcs = get_headers_for_test();
|
||||
// srcs.push_back({"main.cpp", src});
|
||||
// rtc::compile_options options;
|
||||
// options.kernel_name = "f";
|
||||
rtc::hip_compile_options options;
|
||||
options.kernel_name = "f";
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
|
||||
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
|
||||
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
|
||||
options.additional_src_files = ck_headers();
|
||||
// auto k = rtc::compile_kernel(srcs, options);
|
||||
std::cout << src << std::endl;
|
||||
auto k = rtc::compile_hip_code_object(src, options);
|
||||
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
|
||||
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
|
||||
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
|
||||
ck::host::integer_divide_ceil(prob.N, n_per_block);
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
|
||||
|
||||
@@ -187,4 +247,34 @@ TEST_CASE(test_problem_kernel)
|
||||
}
|
||||
}
|
||||
|
||||
// Requests the gemm-softmax-gemm solutions for a 1024x1024x1024x1024 problem on
// gfx90a and prints each solution's template string. Nothing is compiled or launched.
TEST_CASE(test_gemm_softmax_gemm)
{
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
    prob.TransA  = false;
    prob.TransB  = true;
    prob.TransB1 = false;
    prob.TransC  = false;
    prob.M       = 1024;
    prob.N       = 1024;
    prob.K       = 1024;
    prob.O       = 1024;
    check_all<half> check;
    auto a  = to_gpu(generate_buffer<half>(1024 * 1024, 0));
    auto b  = to_gpu(generate_buffer<half>(1024 * 1024, 1));
    auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
    auto c  = to_gpu(generate_buffer<half>(1024 * 1024, 3));

    std::string epilogue = "";
    std::string prologue = "";

    auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
    std::cout << "Num solutions: " << solutions.size() << std::endl;

    // Index with std::size_t so the comparison against size() is not signed/unsigned.
    for(std::size_t i = 0; i < solutions.size(); ++i)
    {
        std::cout << "Solution " << i << std::endl;
        std::cout << solutions[i].ToTemplateString() << std::endl;
        std::cout << std::endl;
    }
}
|
||||
|
||||
int main(int argc, const char* argv[]) { test::run(argc, argv); }
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <rtc/kernel.hpp>
|
||||
#include <ck/filesystem.hpp>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
@@ -19,9 +20,36 @@ struct compile_options
|
||||
std::string kernel_name = "main";
|
||||
};
|
||||
|
||||
struct hip_compile_options
|
||||
{
|
||||
std::size_t global;
|
||||
std::size_t local;
|
||||
std::string kernel_name = "kernel";
|
||||
std::string params = "";
|
||||
std::vector<src_file> additional_src_files = {};
|
||||
|
||||
/**
|
||||
* @brief Set the launch parameters but allow v to override the values
|
||||
*
|
||||
* @param v A value class which can have a "global" and/or "local" keys to override the default
|
||||
* global and local
|
||||
* @param compute_global A function used to compute the global based on the local
|
||||
* @param default_local The defaul local to use if its missing from the v parameter
|
||||
*/
|
||||
void set_launch_params(const std::function<std::size_t(std::size_t local)>& compute_global,
|
||||
std::size_t default_local = 1024);
|
||||
|
||||
void set_launch_params(std::size_t default_global, std::size_t default_local = 1024)
|
||||
{
|
||||
set_launch_params([=](auto) { return default_global; }, default_local);
|
||||
}
|
||||
};
|
||||
|
||||
kernel compile_kernel(const std::vector<src_file>& src,
|
||||
compile_options options = compile_options{});
|
||||
|
||||
kernel compile_hip_code_object(const std::string& content, hip_compile_options options);
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
#include "rtc/hip.hpp"
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <hip/hiprtc.h>
|
||||
#include <rtc/tmp_dir.hpp>
|
||||
#include <stdexcept>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
#include <deque>
|
||||
#include <numeric>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
@@ -100,4 +103,345 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
|
||||
return kernel{obj.data(), options.kernel_name};
|
||||
}
|
||||
|
||||
struct hiprtc_src_file
|
||||
{
|
||||
hiprtc_src_file() = default;
|
||||
hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
|
||||
std::string path;
|
||||
std::string content;
|
||||
template <class Self, class F>
|
||||
static auto reflect(Self& self, F f)
|
||||
{
|
||||
return pack(f(self.path, "path"), f(self.content, "content"));
|
||||
}
|
||||
};
|
||||
|
||||
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
|
||||
{
|
||||
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
|
||||
}
|
||||
|
||||
// Throws std::runtime_error carrying hiprtc_error(err, msg) unless err is HIPRTC_SUCCESS.
// `ctx` is accepted but not used in the message.
void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx)
{
    if(err == HIPRTC_SUCCESS)
        return;
    throw std::runtime_error(hiprtc_error(err, msg));
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_HIPRTC(...) \
|
||||
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
|
||||
|
||||
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
|
||||
|
||||
// Deleter that calls the compile-time function `f` on a non-null pointer and
// discards whatever `f` returns.
template <class F, F f> // NOLINT
struct manage_deleter
{
    template <class T>
    void operator()(T* x) const
    {
        if(x == nullptr)
            return;
        (void)f(x);
    }
};
|
||||
|
||||
template <class T, class F, F f> // NOLINT
|
||||
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;
|
||||
|
||||
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
|
||||
|
||||
// Workaround hiprtc's broken API
|
||||
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
|
||||
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
|
||||
|
||||
template <class... Ts>
|
||||
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
|
||||
{
|
||||
hiprtcProgram prog = nullptr;
|
||||
auto result = hiprtcCreateProgram(&prog, xs...);
|
||||
hiprtc_program_ptr p{prog};
|
||||
if(result != HIPRTC_SUCCESS)
|
||||
MIGRAPHX_HIPRTC_THROW(result, "Create program failed.");
|
||||
return p;
|
||||
}
|
||||
|
||||
// True when `value` begins with `prefix`; an empty prefix always matches.
bool starts_with(const std::string& value, const std::string& prefix)
{
    return prefix.size() <= value.size() and
           std::equal(prefix.begin(), prefix.end(), value.begin());
}
|
||||
|
||||
// True when `value` finishes with `suffix`; an empty suffix always matches.
bool ends_with(const std::string& value, const std::string& suffix)
{
    return suffix.size() <= value.size() and
           std::equal(suffix.rbegin(), suffix.rend(), value.rbegin());
}
|
||||
|
||||
// Splits `s` at every `delim`. Empty pieces are kept, so the result always has
// one more element than there are delimiters (an empty input yields {""}).
std::vector<std::string> split_string(const std::string& s, char delim)
{
    std::vector<std::string> pieces;
    std::size_t start = 0;
    for(;;)
    {
        const std::size_t stop = s.find(delim, start);
        if(stop == std::string::npos)
        {
            // Everything after the last delimiter, possibly empty.
            pieces.push_back(s.substr(start));
            break;
        }
        pieces.push_back(s.substr(start, stop - start));
        start = stop + 1;
    }
    return pieces;
}
|
||||
|
||||
// Concatenates the elements of `strings`, placing `delim` between neighbours.
// Returns an empty string for an empty range. The range is taken by const
// reference so the caller's container is not copied.
template <class Strings>
inline std::string join_strings(const Strings& strings, const std::string& delim)
{
    auto first = strings.begin();
    if(first == strings.end())
        return "";

    return std::accumulate(std::next(first),
                           strings.end(),
                           std::string(*first),
                           [&](std::string joined, const std::string& piece) {
                               joined += delim;
                               joined += piece;
                               return joined;
                           });
}
|
||||
|
||||
struct hiprtc_program
|
||||
{
|
||||
struct string_array
|
||||
{
|
||||
std::deque<std::string> strings{};
|
||||
std::vector<const char*> c_strs{};
|
||||
|
||||
string_array() {}
|
||||
string_array(const string_array&) = delete;
|
||||
|
||||
std::size_t size() const { return strings.size(); }
|
||||
|
||||
const char** data() { return c_strs.data(); }
|
||||
|
||||
void push_back(std::string s)
|
||||
{
|
||||
strings.push_back(std::move(s));
|
||||
c_strs.push_back(strings.back().c_str());
|
||||
}
|
||||
};
|
||||
|
||||
hiprtc_program_ptr prog = nullptr;
|
||||
string_array headers{};
|
||||
string_array include_names{};
|
||||
std::string cpp_src = "";
|
||||
std::string cpp_name = "";
|
||||
|
||||
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
|
||||
: cpp_src(src), cpp_name(name)
|
||||
{
|
||||
create_program();
|
||||
}
|
||||
|
||||
hiprtc_program(std::vector<src_file> srcs)
|
||||
{
|
||||
for(auto&& src : srcs)
|
||||
{
|
||||
if(ends_with(src.path, ".cpp"))
|
||||
{
|
||||
cpp_src = std::move(src.content);
|
||||
cpp_name = std::move(src.path);
|
||||
}
|
||||
else
|
||||
{
|
||||
headers.push_back(std::move(src.content));
|
||||
include_names.push_back(std::move(src.path));
|
||||
}
|
||||
}
|
||||
create_program();
|
||||
}
|
||||
|
||||
void create_program()
|
||||
{
|
||||
assert(not cpp_src.empty());
|
||||
assert(not cpp_name.empty());
|
||||
assert(headers.size() == include_names.size());
|
||||
prog = hiprtc_program_create(cpp_src.c_str(),
|
||||
cpp_name.c_str(),
|
||||
headers.size(),
|
||||
headers.data(),
|
||||
include_names.data());
|
||||
}
|
||||
|
||||
void compile(const std::vector<std::string>& options, bool quiet = false) const
|
||||
{
|
||||
// if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
|
||||
// std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
|
||||
std::vector<const char*> c_options;
|
||||
std::transform(options.begin(),
|
||||
options.end(),
|
||||
std::back_inserter(c_options),
|
||||
[](const std::string& s) { return s.c_str(); });
|
||||
std::cout << "BEFORE HIPRTC COMPILE" << std::endl;
|
||||
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
|
||||
auto prog_log = log();
|
||||
if(not prog_log.empty() and not quiet)
|
||||
{
|
||||
std::cerr << prog_log << std::endl;
|
||||
}
|
||||
if(result != HIPRTC_SUCCESS)
|
||||
throw std::runtime_error("Compilation failed.");
|
||||
}
|
||||
|
||||
std::string log() const
|
||||
{
|
||||
std::size_t n = 0;
|
||||
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
|
||||
if(n == 0)
|
||||
return {};
|
||||
std::string buffer(n, '\0');
|
||||
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
|
||||
assert(buffer.back() != 0);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::vector<char> get_code_obj() const
|
||||
{
|
||||
std::size_t n = 0;
|
||||
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
|
||||
std::vector<char> buffer(n);
|
||||
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
|
||||
return buffer;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<src_file> srcs,
|
||||
const std::string& params,
|
||||
const std::string& arch)
|
||||
{
|
||||
hiprtc_program prog(std::move(srcs));
|
||||
auto options = split_string(params, ' ');
|
||||
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
|
||||
if(true)
|
||||
{
|
||||
options.push_back("-DMIGRAPHX_HAS_DPP=0");
|
||||
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
|
||||
options.push_back("-Wno-reserved-identifier");
|
||||
options.push_back("-Wno-unused-parameter");
|
||||
options.push_back("-Wno-gnu-line-marker");
|
||||
options.push_back("-Wno-old-style-cast");
|
||||
}
|
||||
if(true)
|
||||
options.push_back("-DMIGRAPHX_DEBUG");
|
||||
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
|
||||
return starts_with(s, "--std=") or starts_with(s, "-std=");
|
||||
}))
|
||||
options.push_back("-std=c++17");
|
||||
options.push_back("-fno-gpu-rdc");
|
||||
options.push_back("-O3");
|
||||
options.push_back("-Wno-cuda-compat");
|
||||
options.push_back("--offload-arch=" + arch);
|
||||
prog.compile(options);
|
||||
return {prog.get_code_obj()};
|
||||
}
|
||||
|
||||
bool hip_has_flags(const std::vector<std::string>& flags)
|
||||
{
|
||||
hiprtc_program prog{" "};
|
||||
try
|
||||
{
|
||||
prog.compile(flags, true);
|
||||
return true;
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool hip_accept_non_uniform_wg()
|
||||
{
|
||||
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
|
||||
return non_uniform_wg;
|
||||
}
|
||||
|
||||
static std::vector<std::string> get_compiler_warnings()
|
||||
{
|
||||
std::vector<std::string> warnings = {
|
||||
"-Weverything",
|
||||
"-Wno-c++98-compat",
|
||||
"-Wno-c++98-compat-pedantic",
|
||||
"-Wno-conversion",
|
||||
"-Wno-double-promotion",
|
||||
"-Wno-exit-time-destructors",
|
||||
"-Wno-extra-semi",
|
||||
"-Wno-extra-semi-stmt",
|
||||
"-Wno-float-conversion",
|
||||
"-Wno-gnu-anonymous-struct",
|
||||
"-Wno-gnu-zero-variadic-macro-arguments",
|
||||
"-Wno-missing-prototypes",
|
||||
"-Wno-nested-anon-types",
|
||||
"-Wno-padded",
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-Wno-sign-conversion",
|
||||
"-Wno-sign-compare",
|
||||
"-Wno-unused-command-line-argument",
|
||||
"-Wno-weak-vtables",
|
||||
"-Wno-c99-extensions",
|
||||
};
|
||||
|
||||
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
|
||||
warnings.push_back("-Wno-unsafe-buffer-usage");
|
||||
return warnings;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& compiler_warnings()
|
||||
{
|
||||
static std::vector<std::string> warnings = get_compiler_warnings();
|
||||
return warnings;
|
||||
}
|
||||
|
||||
// Compiles `content` as "main.cpp", together with options.additional_src_files,
// for the current device and returns a kernel named options.kernel_name.
//
// options.global and options.local must both be positive. The launch sizes,
// the warning flags, -ftemplate-backtrace-limit=0 and -Werror are appended to
// options.params before compiling.
//
// Throws std::runtime_error("No code object") unless the compile step returns
// exactly one code object.
kernel compile_hip_code_object(const std::string& content, hip_compile_options options)
{
    assert(options.global > 0);
    assert(options.local > 0);

    std::vector<src_file> srcs = options.additional_src_files;
    srcs.emplace_back("main.cpp", content);

    // List the files handed to the compiler.
    for(const auto& src : srcs)
    {
        std::cout << src.path << std::endl;
    }

    // A global size that is not a multiple of the local size needs the
    // non-uniform block flag, which is only added when the compiler accepts it.
    if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
        options.params += " -fno-offload-uniform-block";
    else
        assert(options.global % options.local == 0);

    options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
    options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
    options.params += " " + join_strings(compiler_warnings(), " ");
    options.params += " -ftemplate-backtrace-limit=0";
    options.params += " -Werror";

    auto cos = compile_hip_src_with_hiprtc(std::move(srcs), options.params, get_device_name());
    // The exception must actually be thrown; otherwise an empty result would
    // reach cos.front() below.
    if(cos.size() != 1)
        throw std::runtime_error("No code object");
    auto& obj = cos.front();

    return kernel{obj.data(), options.kernel_name};
}
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
Reference in New Issue
Block a user