Introduce gemm_softmax_gemm to codegen.

This commit is contained in:
Mirza Halilcevic
2024-09-25 08:22:07 +00:00
parent 3528a523ff
commit d43cd4ad32
52 changed files with 2108 additions and 187 deletions

View File

@@ -128,6 +128,8 @@ list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/ll
message("GPU_TARGETS= ${GPU_TARGETS}")
option(CK_BUILD_HOST_LIB, "Only build the CK JIT Helper Library" OFF)
find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
@@ -254,6 +256,7 @@ elseif(CK_PARALLEL_COMPILE_JOBS)
message(WARNING "Job pooling is only available with Ninja generators.")
endif()
if (NOT CK_BUILD_HOST_LIB)
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
@@ -275,6 +278,8 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
link_libraries(Threads::Threads)
endif() # NOT CK_BUILD_HOST_LIB
## C++
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -291,6 +296,8 @@ if(USE_GLIBCXX_ASSERTIONS)
add_compile_options(-Wp,-D_GLIBCXX_ASSERTIONS)
endif()
if (NOT CK_BUILD_HOST_LIB)
## HIP
set(CMAKE_HIP_PLATFORM amd)
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
@@ -346,6 +353,8 @@ else()
add_compile_definitions(__HIP_PLATFORM_HCC__=1)
endif()
endif() # NOT CK_BUILD_HOST_LIB
## tidy
include(EnableCompilerWarnings)
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
@@ -499,6 +508,8 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS}
)
if (NOT CK_BUILD_HOST_LIB)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
@@ -506,6 +517,8 @@ if(BUILD_DEV)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
endif() # NOT CK_BUILD_HOST_LIB
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
add_compile_options(-fcolor-diagnostics)
endif()
@@ -515,6 +528,8 @@ endif()
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
if (NOT CK_BUILD_HOST_LIB)
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
set(CK_DEVICE_INSTANCES)
@@ -590,6 +605,18 @@ if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCE
add_subdirectory(codegen)
endif()
else() # NOT CK_BUILD_HOST_LIB
if(GPU_TARGETS MATCHES "gfx9")
rocm_package_setup_component(ck_host
LIBRARY_NAME composablekernel
PACKAGE_NAME ck_host
)
add_subdirectory(codegen)
endif()
endif() # NOT CK_BUILD_HOST_LIB
#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)
@@ -627,4 +654,4 @@ rocm_create_package(
MAINTAINER "MIOpen Kernels Dev Team <dl.MIOpen@amd.com>"
LDCONFIG
HEADER_ONLY
)
)

View File

@@ -1,6 +1,6 @@
@PACKAGE_INIT@
set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility)
set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility ck_host)
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
if(NOT _comp IN_LIST _composable_kernel_supported_components)

View File

@@ -31,12 +31,21 @@ file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
##message(STATUS "SOURCE_FILES: ${SOURCES}")
# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
target_link_libraries(ck_host PRIVATE ck_headers)
add_library(composable_kernel::ck_host ALIAS ck_host)
set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON)
target_include_directories(ck_host SYSTEM PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
# $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/library/src/jit_library/solution_instances>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/solution_instances>
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include>
)
target_link_libraries(ck_host PRIVATE $<BUILD_INTERFACE:ck_headers>)
target_include_directories(ck_host PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
)
@@ -45,9 +54,18 @@ add_executable(ck-template-driver driver/main.cpp)
target_link_libraries(ck-template-driver ck_host)
rocm_install(
TARGETS ck_host ck_headers
TARGETS ck_host
EXPORT ck_hostTargets
)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
add_subdirectory(test)
rocm_install(
EXPORT ck_hostTargets
FILE composable_kernelck_hostTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
)
if (NOT CK_BUILD_HOST_LIB)
add_subdirectory(test)
endif()

View File

@@ -0,0 +1,58 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
#include "ck/host/operation/gemm.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines all values need for an instance of fwd conv
struct Operation_Xdl_CShuffle
{
// returns a vector of instances, only given fusion operators: will use default problem spec
static std::vector<std::vector<Operation_Xdl_CShuffle>>
CreateOperations(const std::string& prologue, const std::string& epilogue);
// returns a vector of instances, given a problem spec and fusion operators
static std::vector<Operation_Xdl_CShuffle>
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
TensorDesc A{};
TensorDesc B{};
TensorDesc B1{};
TensorDesc C{};
std::string a_elem_op = PassThrough;
std::string b_elem_op = PassThrough;
std::string b1_elem_op = PassThrough;
std::string c_elem_op = PassThrough;
std::string acc_elem_op = Scale;
std::string prologue = "";
std::string epilogue = "";
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
// tuning parameters
operation::TileDescGemmSoftmaxGemm tile_desc{};
operation::BlockTransferDesc a_block_transfer{};
operation::BlockTransferDesc b0_block_transfer{};
operation::BlockTransferDesc b1_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
bool mask_out_upper_triangle = false;
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
void update_epilogue(const std::string& epilogue);
/**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
// returns a templated instance
Solution ToSolution() const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <string>
#include "ck/host/types.hpp"
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// defines the problem specification for a GEMM operation
struct Problem
{
std::size_t M = 0;
std::size_t N = 0;
std::size_t K = 0;
std::size_t O = 0;
bool TransA = false;
bool TransB = false;
bool TransB1 = false;
bool TransC = false;
DataType ADataType = DataType::Half;
DataType BDataType = DataType::Half;
DataType B1DataType = DataType::Half;
DataType CDataType = DataType::Half;
std::string AElementOp = PassThrough;
std::string BElementOp = PassThrough;
std::string B1ElementOp = PassThrough;
std::string CElementOp = PassThrough;
std::string AccElementOp = Scale;
// returns the correct device op file for the operation
std::string GetIncludeHeader() const;
// returns a list of instances based on the problem spec and provided fusion operations
std::vector<Solution> GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const;
};
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck

View File

@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
LoopScheduler loop_scheduler{};
PipelineVersion pipeline_version{};
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);

View File

@@ -23,6 +23,26 @@ struct TileDesc
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct TileDescGemmSoftmaxGemm
{
int block_size = 0;
int gemm01_m_per_block = 0;
int gemm0_n_per_block = 0;
int gemm0_k_per_block = 0;
int gemm1_n_per_block = 0;
int gemm1_k_per_block = 0;
int ak1 = 0;
int bk1 = 0;
int b1k1 = 0;
int m_per_XDL = 0;
int n_per_XDL = 0;
int gemm0_m_Xdl_per_wave = 0;
int gemm0_n_Xdl_per_wave = 0;
int gemm1_n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
struct BlockTransferDesc
{
std::string thread_cluster_length = "";

View File

@@ -66,6 +66,20 @@ enum class GemmType
};
std::string ToString(GemmType gt);
enum class LoopScheduler
{
Default,
Interwave,
};
std::string ToString(LoopScheduler ls);
enum class PipelineVersion
{
v1,
v2
};
std::string ToString(PipelineVersion pv);
struct TensorDesc
{
DataType element;
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/utils.hpp"
#include <algorithm>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// return the relevant device op file based on the operation
std::string Problem::GetIncludeHeader() const
{
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
}
// returns templated instances when provided with a problem specification
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
const std::string& prologue,
const std::string& epilogue) const
{
if(get_xdlop_archs().count(arch) == 0)
return {};
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
*this, prologue, epilogue); // obtains vector of instances
std::vector<Solution> result;
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
return op.ToSolution(); // template instance with correct values
});
return result;
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck

View File

@@ -0,0 +1,412 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include <cassert>
namespace ck {
namespace host {
namespace device_batched_gemm_softmax_gemm {
// calculate appropriate Gemm Specification based on input tensor dimensions
std::string GetGemmSpec(const std::size_t m,
const std::size_t n,
const std::size_t k,
const std::size_t n1,
const std::size_t m_per_block,
const std::size_t n_per_block,
const std::size_t k_per_block,
const std::size_t n1_per_block)
{
std::string spec = "";
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
spec += "M";
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
spec += "N";
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
spec += "K";
if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
spec += "O";
if(spec == "")
return "ck::tensor_operation::device::GemmSpecialization::Default";
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
}
// function to update prologue/epilogue with user provided operation
void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
{
if(!prologue.empty())
{
this->prologue = pro;
// TODO
// this->cde_elem_op = "CDEElementOp";
}
else
{
this->prologue = "";
}
}
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
{
if(!epilogue.empty())
{
this->epilogue = epi;
// TODO
// this->cde_elem_op = "CDEElementOp";
}
else
{
this->epilogue = "";
}
}
// accounts for all possible combinations of Row/Col major
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
const Problem& prob, const std::string& prologue, const std::string& epilogue)
{
std::vector<Operation_Xdl_CShuffle> result;
std::vector<operation::TileDescGemmSoftmaxGemm> tile_descriptions = {
// clang-format off
// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK|
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
// | | | | | | | | | | | Wave| Wave| Wave| |
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
// Padded fallback kernel
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
// Irregular k
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
// clang-format on
};
const std::vector<operation::BlockTransferDesc> a_block_descriptions = {
// clang-format off
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
// Padded fallback kernel
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
// Irregular k
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
// clang-format on
};
const std::vector<operation::BlockTransferDesc> b1_block_descriptions = {
// clang-format off
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// Padded fallback kernel
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// Irregular k
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
// clang-format on
};
std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
// clang-format off
// CShuffle| CShuffle|
// MXdlPerWave| NXdlPerWave|
// PerShuffle| PerShuffle|
// | |
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 8},
{ 1, 4},
{ 1, 8},
{ 1, 4},
// Padded fallback kernel
{ 1, 2},
{ 1, 2},
// Irregular k
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
{ 1, 2},
// clang-format on
};
std::vector<operation::CBlockTransferDesc> c_block_descriptions = {
// clang-format off
// CBlockTransferClusterLengths| CBlockTransfer
// _MBlock_MWaveMPerXdl| ScalarPerVector
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
// |
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1,16>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 16, 1,16>, 8},
{ S<1, 32, 1, 8>, 8},
// Padded fallback kernel
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// Irregular k
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// clang-format on
};
assert(tile_descriptions.size() == a_block_descriptions.size());
assert(tile_descriptions.size() == b1_block_descriptions.size());
assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size());
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
{
Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a
x.b1_block_transfer = b1_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)};
x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)};
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.b1_elem_op = prob.B1ElementOp;
x.c_elem_op = prob.CElementOp;
x.acc_elem_op = prob.AccElementOp;
x.gemm_specialization = GetGemmSpec(prob.M,
prob.N,
prob.K,
prob.O,
x.tile_desc.gemm01_m_per_block,
x.tile_desc.gemm0_n_per_block,
x.tile_desc.gemm0_k_per_block,
x.tile_desc.gemm1_n_per_block);
x.update_prologue(prologue);
x.update_epilogue(epilogue);
x.mask_out_upper_triangle = true;
result.push_back(x);
x.mask_out_upper_triangle = false;
result.push_back(x);
}
return result;
}
// set up instances when not provided with a problem specification, use default operation values and
// all possible layout combinations
std::vector<std::vector<Operation_Xdl_CShuffle>>
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
{
Problem prob;
prob.TransA = false;
prob.TransB = true;
prob.TransB1 = false;
prob.TransC = false;
return {CreateOperations(prob, prologue, epilogue)};
}
static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate =
"ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, "
"${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, "
"${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
"${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, "
"${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
"${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
"${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, "
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, "
"${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>";
// use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const
{
std::unordered_map<std::string, std::string> values = {
{"name",
std::to_string(this->tile_desc.block_size) + "_" +
std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
std::to_string(this->tile_desc.b1k1) + "_" +
std::to_string(this->tile_desc.m_per_XDL) + "_" +
std::to_string(this->tile_desc.n_per_XDL) + "_" +
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" +
std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
{"LayoutA", ToString(this->A.layout)},
{"LayoutB0", ToString(this->B.layout)},
{"LayoutB1", ToString(this->B1.layout)},
{"LayoutC", ToString(this->C.layout)},
{"ADataType", ToString(this->A.element)},
{"B0DataType", ToString(this->B.element)},
{"B1DataType", ToString(this->B1.element)},
{"CDataType", ToString(this->C.element)},
{"AccDataType", ToString(DataType::Float)},
{"CShuffleDataType", ToString(DataType::Half)},
{"AElementwiseOperation", this->a_elem_op},
{"B0ElementwiseOperation", this->b_elem_op},
{"Acc0ElementwiseOperation", this->acc_elem_op},
{"B1ElementwiseOperation", this->b1_elem_op},
{"CElementwiseOperation", this->c_elem_op},
{"GemmSpecialization", this->gemm_specialization},
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
{"BlockSize", std::to_string(this->tile_desc.block_size)},
{"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
{"AK1", std::to_string(this->tile_desc.ak1)},
{"BK1", std::to_string(this->tile_desc.bk1)},
{"B1K1", std::to_string(this->tile_desc.b1k1)},
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
{"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)},
{"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)},
{"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
this->a_block_transfer.thread_cluster_length},
{"ABlockTransferThreadClusterArrangeOrder",
this->a_block_transfer.thread_cluster_arrange_order},
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
{"ABlockTransferSrcScalarPerVector",
std::to_string(this->a_block_transfer.src_scalar_per_vector)},
{"ABlockTransferDstScalarPerVector_AK1",
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
{"B0BlockTransferThreadClusterLengths_BK0_N_BK1",
this->b0_block_transfer.thread_cluster_length},
{"B0BlockTransferThreadClusterArrangeOrder",
this->b0_block_transfer.thread_cluster_arrange_order},
{"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order},
{"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)},
{"B0BlockTransferSrcScalarPerVector",
std::to_string(this->b0_block_transfer.src_scalar_per_vector)},
{"B0BlockTransferDstScalarPerVector_BK1",
std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)},
{"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)},
{"B1BlockTransferThreadClusterLengths_BK0_N_BK1",
this->b1_block_transfer.thread_cluster_length},
{"B1BlockTransferThreadClusterArrangeOrder",
this->b1_block_transfer.thread_cluster_arrange_order},
{"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order},
{"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)},
{"B1BlockTransferSrcScalarPerVector",
std::to_string(this->b1_block_transfer.src_scalar_per_vector)},
{"B1BlockTransferDstScalarPerVector_BK1",
std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)},
{"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)},
{"CShuffleMXdlPerWavePerShuffle",
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
{"CShuffleNXdlPerWavePerShuffle",
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
{"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl",
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CBlockTransferScalarPerVector_NWaveNPerXdl",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
{"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)},
};
return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values),
std::move(values)};
}
} // namespace device_batched_gemm_softmax_gemm
} // namespace host
} // namespace ck

View File

@@ -62,6 +62,13 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
@@ -83,6 +90,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
// Irregular tile
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
// clang-format on
};
@@ -100,6 +109,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on
};
@@ -109,15 +120,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
};
std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
@@ -134,6 +147,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on
};
@@ -151,6 +166,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
// Irregular tile
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on
};
@@ -167,6 +184,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
{ 1, 1},
// clang-format on
};
@@ -185,6 +203,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
{ S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
// Irregular tile
{ S<1, 16, 1, 4>, 1},
// clang-format on
};
@@ -199,33 +219,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size());
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
{
{LoopScheduler::Default, PipelineVersion::v1},
{LoopScheduler::Interwave, PipelineVersion::v1},
{LoopScheduler::Default, PipelineVersion::v2},
};
for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
{
Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b_block_transfer = b_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
return TensorDesc{dt, ToLayout(trans)};
});
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.cde_elem_op = prob.CDEElementOp;
x.gemm_specialization = GetGemmSpec(prob.M,
prob.N,
prob.K,
x.tile_desc.m_per_block,
x.tile_desc.n_per_block,
x.tile_desc.k_per_block);
x.update_prologue(prologue);
x.update_epilogue(epilogue);
result.push_back(x);
// Put all values together into a single operation > store into the result vector
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
{
Operation_Xdl_CShuffle x;
x.tile_desc = tile_descriptions[i];
x.a_block_transfer = a_block_descriptions[i];
x.b_block_transfer = b_block_descriptions[i];
x.cshuffle = cshuffle_descriptions[i];
x.c_block_transfer = c_block_descriptions[i];
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
return TensorDesc{dt, ToLayout(trans)};
});
x.a_elem_op = prob.AElementOp;
x.b_elem_op = prob.BElementOp;
x.cde_elem_op = prob.CDEElementOp;
x.gemm_specialization = GetGemmSpec(prob.M,
prob.N,
prob.K,
x.tile_desc.m_per_block,
x.tile_desc.n_per_block,
x.tile_desc.k_per_block);
x.loop_scheduler = loop_scheduler;
x.pipeline_version = pipeline_version;
x.update_prologue(prologue);
x.update_epilogue(epilogue);
result.push_back(x);
}
}
return result;
}
@@ -263,7 +294,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
"${CDEBlockTransferScalarPerVector_NPerBlock}>";
"${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
// use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const
@@ -336,6 +367,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CDEBlockTransferScalarPerVector_NPerBlock",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
{"LoopScheduler", ToString(this->loop_scheduler)},
{"PipelineVersion", ToString(this->pipeline_version)},
};
return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),

View File

@@ -56,6 +56,26 @@ std::string ToString(GemmType gt)
throw std::runtime_error("Incorrect gemm type");
}
std::string ToString(LoopScheduler ls)
{
switch(ls)
{
case LoopScheduler::Default: return "ck::LoopScheduler::Default";
case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
}
throw std::runtime_error("Incorrect LoopScheduler type");
}
std::string ToString(PipelineVersion pv)
{
switch(pv)
{
case PipelineVersion::v1: return "ck::PipelineVersion::v1";
case PipelineVersion::v2: return "ck::PipelineVersion::v2";
}
throw std::runtime_error("Incorrect PipelineVersion type");
}
std::string SequenceStr(const std::vector<int>& v)
{
return "ck::Sequence<" +

View File

@@ -15,7 +15,8 @@ std::vector<rtc::src_file> get_headers_for_test()
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
std::string sec(p.second.begin(), p.second.end());
return {p.first, sec};
});
return result;
}

View File

@@ -1,5 +1,7 @@
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
@@ -15,13 +17,59 @@
using half = _Float16;
// using half = __fp16;
// NOLINTNEXTLINE
const char* const disable_warning_pragma = R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__";
template <class P>
std::string ck_disable_warnings(P p)
{
return ck::host::InterpolateString(disable_warning_pragma,
{{"content", std::string{p.data(), p.size()}}});
}
static std::unordered_map<std::string, std::string> create_ck_header_strings()
{
std::unordered_map<std::string, std::string> result;
auto ck_headers = ck::host::GetHeaders();
std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
});
return result;
}
static std::vector<rtc::src_file> create_ck_headers()
{
static const auto& header_strings = create_ck_header_strings();
std::vector<rtc::src_file> srcs;
std::transform(
header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto& p) -> rtc::src_file {
std::string sec(p.second.begin(), p.second.end());
return {p.first, sec};
});
return srcs;
}
static inline const std::vector<rtc::src_file>& ck_headers()
{
static const auto& headers = create_ck_headers();
return headers;
}
std::vector<rtc::src_file> get_headers_for_test()
{
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
std::string sec(p.second.begin(), p.second.end());
return {p.first, sec};
});
return result;
}
@@ -130,10 +178,13 @@ const std::string gemm_compile_check = R"__ck__(
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template};
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
constexpr auto desc =
G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
${k})),
ck::make_naive_tensor_descriptor(ck::make_tuple(${n},
${k}), ck::make_tuple(1, ${n})), ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
${n})));
static_assert(desc.IsValid(), "Invalid ck gemm.");
@@ -163,23 +214,32 @@ TEST_CASE(test_problem_kernel)
std::string epilogue = "";
std::string prologue = "";
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i)
{
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
auto srcs = get_headers_for_test();
srcs.push_back({"main.cpp", src});
rtc::compile_options options;
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
auto&& solution = solutions[i];
auto src = ck::host::InterpolateString(gemm_compile_check,
{{"include", prob.GetIncludeHeader()},
{"template", solution.ToTemplateString()},
{"m", std::to_string(prob.M)},
{"n", std::to_string(prob.N)},
{"k", std::to_string(prob.K)}});
// auto srcs = get_headers_for_test();
// srcs.push_back({"main.cpp", src});
// rtc::compile_options options;
// options.kernel_name = "f";
rtc::hip_compile_options options;
options.kernel_name = "f";
auto k = rtc::compile_kernel(srcs, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
options.additional_src_files = ck_headers();
// auto k = rtc::compile_kernel(srcs, options);
std::cout << src << std::endl;
auto k = rtc::compile_hip_code_object(src, options);
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
@@ -187,4 +247,34 @@ TEST_CASE(test_problem_kernel)
}
}
TEST_CASE(test_gemm_softmax_gemm)
{
ck::host::device_batched_gemm_softmax_gemm::Problem prob;
prob.TransA = false;
prob.TransB = true;
prob.TransB1 = false;
prob.TransC = false;
prob.M = 1024;
prob.N = 1024;
prob.K = 1024;
prob.O = 1024;
check_all<half> check;
auto a = to_gpu(generate_buffer<half>(1024 * 1024, 0));
auto b = to_gpu(generate_buffer<half>(1024 * 1024, 1));
auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
auto c = to_gpu(generate_buffer<half>(1024 * 1024, 3));
std::string epilogue = "";
std::string prologue = "";
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
std::cout << "Num solutions: " << solutions.size() << std::endl;
for(auto i = 0; i < solutions.size(); ++i) {
std::cout << "Solution " << i << std::endl;
std::cout << solutions[i].ToTemplateString() << std::endl;
std::cout << std::endl;
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }

View File

@@ -4,6 +4,7 @@
#include <rtc/kernel.hpp>
#include <ck/filesystem.hpp>
#include <string>
#include <functional>
namespace rtc {
@@ -19,9 +20,36 @@ struct compile_options
std::string kernel_name = "main";
};
struct hip_compile_options
{
std::size_t global;
std::size_t local;
std::string kernel_name = "kernel";
std::string params = "";
std::vector<src_file> additional_src_files = {};
/**
* @brief Set the launch parameters but allow v to override the values
*
* @param v A value class which can have a "global" and/or "local" keys to override the default
* global and local
* @param compute_global A function used to compute the global based on the local
* @param default_local The defaul local to use if its missing from the v parameter
*/
void set_launch_params(const std::function<std::size_t(std::size_t local)>& compute_global,
std::size_t default_local = 1024);
void set_launch_params(std::size_t default_global, std::size_t default_local = 1024)
{
set_launch_params([=](auto) { return default_global; }, default_local);
}
};
kernel compile_kernel(const std::vector<src_file>& src,
compile_options options = compile_options{});
kernel compile_hip_code_object(const std::string& content, hip_compile_options options);
} // namespace rtc
#endif

View File

@@ -4,6 +4,7 @@
#include <hip/hip_runtime_api.h>
#include <memory>
#include <string>
#include <stdexcept>
namespace rtc {

View File

@@ -1,10 +1,13 @@
#include "rtc/hip.hpp"
#include <rtc/compile_kernel.hpp>
#include <hip/hiprtc.h>
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <iostream>
#include <fstream>
#include <cassert>
#include <deque>
#include <numeric>
namespace rtc {
@@ -100,4 +103,345 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
return kernel{obj.data(), options.kernel_name};
}
struct hiprtc_src_file
{
hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
std::string path;
std::string content;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.path, "path"), f(self.content, "content"));
}
};
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
}
void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx)
{
if(err != HIPRTC_SUCCESS)
throw std::runtime_error(hiprtc_error(err, msg));
}
// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
template <class F, F f> // NOLINT
struct manage_deleter
{
template <class T>
void operator()(T* x) const
{
if(x != nullptr)
{
(void)f(x);
}
}
};
template <class T, class F, F f> // NOLINT
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
// Workaround hiprtc's broken API
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
template <class... Ts>
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
{
hiprtcProgram prog = nullptr;
auto result = hiprtcCreateProgram(&prog, xs...);
hiprtc_program_ptr p{prog};
if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Create program failed.");
return p;
}
bool starts_with(const std::string& value, const std::string& prefix)
{
if(prefix.size() > value.size())
return false;
else
return std::equal(prefix.begin(), prefix.end(), value.begin());
}
bool ends_with(const std::string& value, const std::string& suffix)
{
if(suffix.size() > value.size())
return false;
else
return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin());
}
std::vector<std::string> split_string(const std::string& s, char delim)
{
std::vector<std::string> elems;
std::stringstream ss(s + delim);
std::string item;
while(std::getline(ss, item, delim))
{
elems.push_back(item);
}
return elems;
}
template <class Strings>
inline std::string join_strings(Strings strings, const std::string& delim)
{
auto it = strings.begin();
if(it == strings.end())
return "";
auto nit = std::next(it);
return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) {
return std::move(x) + delim + std::move(y);
});
}
struct hiprtc_program
{
struct string_array
{
std::deque<std::string> strings{};
std::vector<const char*> c_strs{};
string_array() {}
string_array(const string_array&) = delete;
std::size_t size() const { return strings.size(); }
const char** data() { return c_strs.data(); }
void push_back(std::string s)
{
strings.push_back(std::move(s));
c_strs.push_back(strings.back().c_str());
}
};
hiprtc_program_ptr prog = nullptr;
string_array headers{};
string_array include_names{};
std::string cpp_src = "";
std::string cpp_name = "";
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
: cpp_src(src), cpp_name(name)
{
create_program();
}
hiprtc_program(std::vector<src_file> srcs)
{
for(auto&& src : srcs)
{
if(ends_with(src.path, ".cpp"))
{
cpp_src = std::move(src.content);
cpp_name = std::move(src.path);
}
else
{
headers.push_back(std::move(src.content));
include_names.push_back(std::move(src.path));
}
}
create_program();
}
void create_program()
{
assert(not cpp_src.empty());
assert(not cpp_name.empty());
assert(headers.size() == include_names.size());
prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(),
headers.size(),
headers.data(),
include_names.data());
}
void compile(const std::vector<std::string>& options, bool quiet = false) const
{
// if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
// std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
std::vector<const char*> c_options;
std::transform(options.begin(),
options.end(),
std::back_inserter(c_options),
[](const std::string& s) { return s.c_str(); });
std::cout << "BEFORE HIPRTC COMPILE" << std::endl;
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
auto prog_log = log();
if(not prog_log.empty() and not quiet)
{
std::cerr << prog_log << std::endl;
}
if(result != HIPRTC_SUCCESS)
throw std::runtime_error("Compilation failed.");
}
std::string log() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n == 0)
return {};
std::string buffer(n, '\0');
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() != 0);
return buffer;
}
std::vector<char> get_code_obj() const
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
return buffer;
}
};
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<src_file> srcs,
const std::string& params,
const std::string& arch)
{
hiprtc_program prog(std::move(srcs));
auto options = split_string(params, ' ');
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
if(true)
{
options.push_back("-DMIGRAPHX_HAS_DPP=0");
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
options.push_back("-Wno-reserved-identifier");
options.push_back("-Wno-unused-parameter");
options.push_back("-Wno-gnu-line-marker");
options.push_back("-Wno-old-style-cast");
}
if(true)
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
return starts_with(s, "--std=") or starts_with(s, "-std=");
}))
options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc");
options.push_back("-O3");
options.push_back("-Wno-cuda-compat");
options.push_back("--offload-arch=" + arch);
prog.compile(options);
return {prog.get_code_obj()};
}
bool hip_has_flags(const std::vector<std::string>& flags)
{
hiprtc_program prog{" "};
try
{
prog.compile(flags, true);
return true;
}
catch(...)
{
return false;
}
}
bool hip_accept_non_uniform_wg()
{
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
return non_uniform_wg;
}
static std::vector<std::string> get_compiler_warnings()
{
std::vector<std::string> warnings = {
"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions",
};
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
warnings.push_back("-Wno-unsafe-buffer-usage");
return warnings;
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = get_compiler_warnings();
return warnings;
}
kernel compile_hip_code_object(const std::string& content, hip_compile_options options)
{
assert(options.global > 0);
assert(options.local > 0);
// assert(not options.inputs.empty());
// assert(options.inputs.size() == options.virtual_inputs.size() or
// options.virtual_inputs.empty());
std::vector<src_file> srcs = options.additional_src_files;
// Neko sranje
// static auto kernels{::migraphx_kernels()};
// std::transform(
// kernels.begin(),
// kernels.end(),
// std::back_inserter(srcs),
// [](const std::pair<std::string_view, std::string_view>& elem) { return src_file{elem};
// });
srcs.emplace_back("main.cpp", content);
for (auto src : srcs) {
std::cout << src.path << std::endl;
}
// auto args_hpp =
// generate_args_hpp(options.virtual_inputs.empty() ? options.inputs :
// options.virtual_inputs);
// srcs.emplace_back("args.hpp", args_hpp);
if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
options.params += " -fno-offload-uniform-block";
else
assert(options.global % options.local == 0);
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src_with_hiprtc(srcs, options.params, get_device_name());
if(cos.size() != 1)
std::runtime_error("No code object");
auto& obj = cos.front();
return kernel{obj.data(), options.kernel_name};
}
} // namespace rtc

View File

@@ -4,16 +4,12 @@
#pragma once
#include "ck/config.h"
#include "ck/utility/env.hpp"
#ifndef __HIPCC_RTC__
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#endif
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#endif
// to do: add various levels of logging with CK_LOG_LEVEL

View File

@@ -3,6 +3,7 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <string>
#include <map>
#include <hip/hip_runtime.h>
@@ -96,3 +97,4 @@ inline bool is_gfx12_supported()
}
} // namespace ck
#endif

View File

@@ -2,7 +2,7 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <hip/hip_runtime.h>
#include "ck/ck.hpp"
@@ -160,3 +160,4 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return 0;
#endif
}
#endif

View File

@@ -3,15 +3,17 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <string>
#include <sstream>
#include "ck/stream_config.hpp"
#endif
namespace ck {
namespace tensor_operation {
namespace device {
#ifndef __HIPCC_RTC__
struct BaseArgument
{
BaseArgument() = default;
@@ -36,6 +38,7 @@ struct BaseInvoker
virtual ~BaseInvoker() {}
};
#endif
struct BaseOperator
{
@@ -43,6 +46,7 @@ struct BaseOperator
BaseOperator(const BaseOperator&) = default;
BaseOperator& operator=(const BaseOperator&) = default;
#ifndef __HIPCC_RTC__
virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
virtual std::string GetTypeString() const { return ""; }
@@ -66,7 +70,7 @@ struct BaseOperator
assert(p_arg);
p_arg->p_workspace_ = p_workspace;
}
#endif
virtual ~BaseOperator() {}
};

View File

@@ -2,9 +2,10 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <vector>
#endif
#include "device_base.hpp"
@@ -28,6 +29,7 @@ template <typename ALayout,
bool MaskOutUpperTriangle> // TODO: enum for mask type
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b0,
@@ -53,6 +55,7 @@ struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
} // namespace device

View File

@@ -2,9 +2,11 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <array>
#endif
#include "ck/utility/array.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
@@ -34,6 +36,7 @@ struct DeviceGemmMultipleD : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
@@ -51,6 +54,7 @@ struct DeviceGemmMultipleD : public BaseOperator
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
// GEMM:
@@ -76,6 +80,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
@@ -94,6 +99,7 @@ struct DeviceGemmMultipleDSplitK : public BaseOperator
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};
} // namespace device

View File

@@ -28,7 +28,7 @@ enum struct GemmSpecialization
NKOPadding,
MNKOPadding,
};
#ifndef __HIPCC_RTC__
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
{
switch(s)
@@ -52,6 +52,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
default: return "Unrecognized specialization!";
}
}
#endif
} // namespace device
} // namespace tensor_operation

View File

@@ -3,8 +3,12 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -15,8 +19,6 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
@@ -40,27 +42,27 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
const FloatAB* __restrict__ p_b1_grid,
FloatC* __restrict__ p_c_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const AccElementwiseOperation acc_element_op,
const B1ElementwiseOperation b1_element_op,
const CElementwiseOperation c_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap block_2_ctile_map,
const index_t batch_count,
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
const C0MatrixMask c0_matrix_mask)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
@@ -430,6 +432,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
matrix_padder.PadN,
MaskOutUpperTriangle>;
#ifndef __HIPCC_RTC__
// Argument
struct Argument : public BaseArgument
{
@@ -604,6 +607,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
#endif
static constexpr bool IsValidCompilationParameter()
{
@@ -611,6 +615,97 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return true;
}
static constexpr bool
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
{
// check vector load/store
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// check vector load of A
if constexpr(is_same_v<ALayout, Row>)
{
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col>)
{
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B
if constexpr(is_same_v<BLayout, Row>)
{
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Col>)
{
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of B1
if constexpr(is_same_v<B1Layout, Row>)
{
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<B1Layout, Col>)
{
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// check vector load of C
if constexpr(is_same_v<CLayout, Row>)
{
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else if constexpr(is_same_v<CLayout, Col>)
{
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
@@ -765,8 +860,271 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return str.str();
}
#endif
template <class ADesc, class BDesc, class B1Desc, class CDesc>
struct Descriptor
{
template <class AGridDescriptor>
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
{
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(
a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class BGridDescriptor>
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
{
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(
b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class B1GridDescriptor>
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
{
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
const auto N = b1_grid_desc_n_k.GetLength(I0);
const auto K = b1_grid_desc_n_k.GetLength(I1);
const auto B1K0 = K / B1K1;
return transform_tensor_descriptor(
b1_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
template <class CGridDescriptor>
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
{
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
using B1GridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
ADataType, // TODO: distinguish A/B datatype
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
B1GridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
Gemm1NPerBlock,
Gemm1KPerBlock,
AK1,
BK1,
B1K1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
Gemm1NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
true,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
true,
BBlockLdsExtraN,
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_BK1,
false,
B1BlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched,
matrix_padder.PadN,
MaskOutUpperTriangle>;
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
CGridDesc_M_N c_grid_desc_m_n;
C0MatrixMask c0_matrix_mask;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
// element-wise op
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;
bool has_main_k_block_loop = true;
bool is_valid = false;
constexpr Descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n)},
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
c0_matrix_mask{c.GetLength(I1)},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1,
c_grid_desc_m_n,
block_2_ctile_map) and
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
b_grid_desc_bk0_n_bk1.GetLength(I1),
a_grid_desc_ak0_m_ak1.GetLength(I0) *
a_grid_desc_ak0_m_ak1.GetLength(I2),
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
{
}
constexpr bool IsValid() const { return is_valid; }
};
template <class ADesc, class BDesc, class B1Desc, class CDesc>
static constexpr auto
make_descriptor(ADesc a,
BDesc b,
B1Desc b1,
CDesc c,
AElementwiseOperation a_element_op = AElementwiseOperation{},
BElementwiseOperation b_element_op = BElementwiseOperation{},
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
CElementwiseOperation c_element_op = CElementwiseOperation{})
{
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
}
template <class Desc>
__device__ static void Run(const Desc& desc,
const float scale,
const ADataType* __restrict__ p_a_grid,
const ADataType* __restrict__ p_b_grid,
const ADataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid)
{
#ifndef __HIPCC_RTC__
assert(desc.is_valid);
#endif
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
AccElementwiseOperation acc_element_op{scale};
if(desc.has_main_k_block_loop)
{
Desc::GridwiseGemm::template Run<true>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
else
{
Desc::GridwiseGemm::template Run<false>(
p_a_grid,
p_b_grid,
p_b1_grid,
p_c_grid,
p_shared_block,
desc.a_element_op,
desc.b_element_op,
acc_element_op,
desc.b1_element_op,
desc.c_element_op,
desc.a_grid_desc_ak0_m_ak1,
desc.b_grid_desc_bk0_n_bk1,
desc.b1_grid_desc_bk0_n_bk1,
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
desc.block_2_ctile_map,
desc.c0_matrix_mask);
}
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
} // namespace ck

View File

@@ -3,8 +3,12 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
@@ -14,8 +18,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
@@ -35,22 +37,22 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
@@ -225,9 +227,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
static auto MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
const Array<index_t, NumDTensor>& NRaws,
const Array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
@@ -309,6 +311,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
#ifndef __HIPCC_RTC__
// Argument
struct Argument : public BaseArgument
{
@@ -498,6 +501,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
}
};
#endif
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
{
// check vector load/store
@@ -578,6 +583,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return true;
}
#ifndef __HIPCC_RTC__
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
@@ -676,11 +682,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<LoopScheduler, std::string> LoopSchedToString{{LoopScheduler::Default, "Default"},
{ LoopScheduler::Interwave,
"Interwave" }};
std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
{PipelineVersion::v2, "v2"}};
{ PipelineVersion::v2,
"v2" }};
// clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle"
@@ -709,6 +717,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return str.str();
}
#endif
template <class ADesc, class BDesc, class DsDesc, class EDesc>
struct Descriptor
@@ -847,7 +856,9 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType* __restrict__ p_e_grid)
{
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
#ifndef __HIPCC_RTC__
assert(desc.IsValid());
#endif
if(desc.has_main_k_block_loop)
{
GridwiseGemm::template Run<true>(p_a_grid,

View File

@@ -13,6 +13,7 @@ enum struct MaskingSpecialization
MaskOutUpperTriangle
};
#ifndef __HIPCC_RTC__
inline std::string getMaskingSpecializationString(const MaskingSpecialization& s)
{
switch(s)
@@ -22,6 +23,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
default: return "Unrecognized specialization!";
}
}
#endif
struct MaskDisabledPredicate
{
@@ -53,7 +55,7 @@ struct MaskOutUpperTrianglePredicate
template <typename MaskOutPredicate>
struct C0MatrixMask_impl
{
__host__ __device__ C0MatrixMask_impl(index_t NRaw)
__host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw)
: NRaw_(NRaw), predicate_(MaskOutPredicate{})
{
}

View File

@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout
} // namespace convolution
#ifndef __HIPCC_RTC__
template <
typename Layout,
typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
os << Layout::name;
return os;
}
#endif
} // namespace tensor_layout
} // namespace ck

View File

@@ -340,8 +340,8 @@ struct Bilinear
};
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
__host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t>(
int8_t& y, const int32_t& x0, const int8_t& x1) const
{
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1));

View File

@@ -466,7 +466,7 @@ struct FastGelu
template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const;
#ifndef __HIPCC_RTC__
template <>
__host__ void operator()<float, float>(float& y, const float& x) const
{
@@ -477,7 +477,7 @@ struct FastGelu
const float emu = exp(u);
y = x / (1.f + emu);
}
#endif
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
__device__ void operator()<float, float>(float& y, const float& x) const

View File

@@ -7,8 +7,10 @@
#include "ck/utility/number.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#ifndef __HIPCC_RTC__
#include <limits>
#include <stdlib.h>
#endif
namespace ck {
@@ -979,7 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return std::make_tuple(N0, M0, k_split);
return ck::make_tuple(N0, M0, k_split);
}
template <typename TopIdx>
@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
ck::NumericLimits<int>::Max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++)
{

View File

@@ -475,9 +475,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
MakeDsGridDescriptor_M_N(const Array<index_t, NumDTensor>& MRaws,
const Array<index_t, NumDTensor>& NRaws,
const Array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
@@ -941,7 +941,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const Array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const Block2ETileMap& block_2_etile_map)
{

View File

@@ -3,8 +3,10 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <ostream>
#endif
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
@@ -53,12 +55,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
}
else
{
#ifndef __HIPCC_RTC__
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
#endif
}
}
} // namespace ck
#ifndef __HIPCC_RTC__
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{
switch(p)
@@ -71,3 +76,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
}
return os;
}
#endif

View File

@@ -1005,6 +1005,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
#ifndef __HIPCC_RTC__
template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset,
@@ -1042,5 +1043,6 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}
#endif
} // namespace ck

View File

@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"
#ifndef __HIPCC_RTC__
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#endif
namespace ck {
namespace detail {
@@ -37,7 +39,7 @@ struct get_carrier<3>
{
using value_type = uint32_t;
std::array<std::byte, 3> bytes;
Array<ck::byte, 3> bytes;
static_assert(sizeof(bytes) <= sizeof(value_type));
// replacement of host std::copy_n()
@@ -61,22 +63,22 @@ struct get_carrier<3>
// method to trigger template substitution failure
__device__ carrier(const carrier& other) noexcept
{
copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
}
public:
__device__ carrier& operator=(value_type value) noexcept
{
copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());
return *this;
}
__device__ operator value_type() const noexcept
{
std::byte result[sizeof(value_type)];
ck::byte result[sizeof(value_type)];
copy_n(bytes.begin(), bytes.size(), result);
copy_n(bytes.begin(), bytes.Size(), result);
return *reinterpret_cast<const value_type*>(result);
}
@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
constexpr unsigned object_size = sizeof(int64_t);
constexpr unsigned second_part_offset = object_size / 2;
auto* const from_obj = reinterpret_cast<const std::byte*>(&value);
alignas(int64_t) std::byte to_obj[object_size];
auto* const from_obj = reinterpret_cast<const ck::byte*>(&value);
alignas(int64_t) ck::byte to_obj[object_size];
using Sgpr = uint32_t;
@@ -124,15 +126,15 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
template <
typename Object,
typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
__device__ auto amd_wave_read_first_lane(const Object& obj)
{
using Size = unsigned;
constexpr Size SgprSize = 4;
constexpr Size ObjectSize = sizeof(Object);
auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
alignas(Object) std::byte to_obj[ObjectSize];
auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
alignas(Object) ck::byte to_obj[ObjectSize];
constexpr Size RemainedSize = ObjectSize % SgprSize;
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;

View File

@@ -38,6 +38,8 @@ struct Array
}
__host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
__host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
__host__ __device__ constexpr TData* begin() { return &mData[0]; }
__host__ __device__ constexpr TData* end() { return &mData[NSize]; }
};
// empty Array
@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
}
// make empty array

View File

@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
[&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
[&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>

View File

@@ -5,8 +5,25 @@
#include "ck/utility/statically_indexed_array.hpp"
#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using float_t = float;
#endif // __HIPCC_RTC__
namespace ck {
#ifdef __HIPCC_RTC__
using byte = unsigned char;
#else
using std::byte;
#endif
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
@@ -1060,6 +1077,146 @@ using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
#ifdef __HIPCC_RTC__
template <typename T>
struct NumericLimits;
template <>
struct NumericLimits<int32_t>
{
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int16_t>
{
__host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
__host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int8_t>
{
__host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
__host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint32_t>
{
__host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
__host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint16_t>
{
__host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
__host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<float>
{
static constexpr unsigned int binary_min = 0x00800000;
static constexpr unsigned int binary_max = 0x7F7FFFFF;
static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
static constexpr unsigned int binary_qnan = 0xFFC00001;
static constexpr unsigned int binary_inf = 0x7F8000000;
__host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
__host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
__host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
__host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
__host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
template <>
struct NumericLimits<half_t>
{
static constexpr unsigned short binary_min = 0x0400;
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;
static constexpr unsigned short binary_qnan = 0x7FFF;
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<int4_t>
{
__host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
};
#else
template <typename T>
struct NumericLimits
{
@@ -1151,6 +1308,7 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
};
#endif
template <typename T>
struct NumericUtils

View File

@@ -4,11 +4,26 @@
#pragma once
namespace ck {
#ifdef __HIPCC_RTC__
template <bool B, class T = void>
struct enable_if
{
};
template <class T>
struct enable_if<true, T>
{
using type = T;
};
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#else
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
#endif
} // namespace ck

View File

@@ -183,3 +183,7 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
} // namespace ck
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)

View File

@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
if constexpr(predicate)
{
return std::forward<X>(x);
return ck::forward<X>(x);
}
else
{
return std::forward<Y>(y);
return ck::forward<Y>(y);
}
}

View File

@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...);
}
};
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
std::forward<Y>(y).At(Number<Js>{})...);
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
ck::forward<Y>(y).At(Number<Js>{})...);
}
};
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
ck::forward<F>(f), ck::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
ck::forward<F>(f), ck::forward<X>(x), ck::forward<Y>(y));
}
} // namespace ck

View File

@@ -9,14 +9,14 @@ namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
using value_t = std::false_type;
using value_t = ck::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
{
using value_t = std::true_type;
using value_t = ck::true_type;
using type = Op<Args...>;
};
} // namespace detail
@@ -32,12 +32,12 @@ template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
template <typename T>
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);
template <typename T>
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);
template <typename T>
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
} // namespace ck

View File

@@ -1,8 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ostream>
#pragma once
#ifndef __HIPCC_RTC__
#include <ostream>
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
@@ -26,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
} // namespace ck
#ifndef __HIPCC_RTC__
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
switch(s)
@@ -36,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
}
return os;
}
#endif

View File

@@ -30,7 +30,7 @@ struct MagicDivision
// WARNING: magic division is only applicable for division inside this range.
// You should use the return value of CalculateMagicNumbers, if division is not inside this
// range. The "else" logic below is to quiet down run-time error.
if(divisor >= 1 && divisor <= INT32_MAX)
if(divisor >= 1 && divisor <= ck::NumericLimits<int32_t>::Max())
{
uint32_t shift = 0;
for(shift = 0; shift < 32; ++shift)

View File

@@ -18,6 +18,7 @@ namespace math {
extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
#ifndef __HIPCC_RTC__
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); };
@@ -457,6 +458,7 @@ inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
@@ -920,5 +922,23 @@ inline __device__ double expm1<double>(double x)
return expm1(x);
};
template <typename T>
inline __device__ T cos(T x)
{
return ck::type_convert<T>(cosf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float cos<float>(float x)
{
return cosf(x);
};
template <>
inline __device__ double cos<double>(double x)
{
return cos(x);
};
} // namespace math
} // namespace ck

View File

@@ -7,7 +7,7 @@ namespace ck {
// Pseudo random number generator
// version for fp32
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
@@ -23,7 +23,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
@@ -40,12 +40,18 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
template <typename T,
uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
ck::enable_if_t<!(ck::is_same<float, T>{} || ck::is_same<half_t, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{
#ifdef __HIPCC_RTC__
static_cast<void>(id);
static_cast<void>(val);
static_cast<void>(seed);
#else
std::ignore = id;
std::ignore = val;
std::ignore = seed;
#endif
return 0;
}

View File

@@ -3,7 +3,9 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <ostream>
#endif
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck
#ifndef __HIPCC_RTC__
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
os << S::At(S::Size() - ck::Number<1>{}).value << "}";
return os;
}
#endif

View File

@@ -32,7 +32,7 @@ struct TupleElementKeyData
template <typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v))
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
{
}
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{
return std::forward(x.mData);
return ck::forward(x.mData);
}
template <typename Indices, typename... Xs>
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
!is_same<remove_cvref_t<Y>, TupleImpl>::value,
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
: TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
{
}
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
: TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
{
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size");
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename Y,
typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
__host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
{
}
template <typename... Ys,
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
{
}
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...);
return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie

View File

@@ -29,7 +29,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const Tuple<Y&...>& ty)
{
return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
[&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
@@ -38,7 +38,7 @@ template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
[&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx,
ty);
}
@@ -157,6 +157,7 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
}
}
#ifndef __HIPCC_RTC__
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
@@ -165,6 +166,7 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
return (is_detected<is_tuple, Ts>::value || ...);
}
#endif
template <index_t depth = 0, typename T>
__host__ __device__ constexpr auto TupleDepth(const T&)

View File

@@ -8,6 +8,158 @@
#include "ck/utility/enable_if.hpp"
namespace ck {
#ifdef __HIPCC_RTC__
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
CK_BUILTIN_TYPE_TRAIT1(is_class);
CK_BUILTIN_TYPE_TRAIT1(is_pointer);
CK_BUILTIN_TYPE_TRAIT1(is_reference);
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
CK_BUILTIN_TYPE_TRAIT2(is_base_of);
template <class T>
struct remove_cv
{
using type = T;
};
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
template <class T>
struct remove_reference
{
typedef T type;
};
template <class T>
struct remove_reference<T&>
{
typedef T type;
};
template <class T>
struct remove_reference<T&&>
{
typedef T type;
};
template <class T>
struct remove_pointer
{
typedef T type;
};
template <class T>
struct remove_pointer<T*>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* volatile>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const volatile>
{
typedef T type;
};
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return static_cast<T&&>(t_);
}
// TODO
template<class T> struct is_const : false_type {};
template<class T> struct is_const<const T> : true_type {};
template< class T >
inline constexpr bool is_const_v = is_const<T>::value;
template< class T >
inline constexpr bool is_reference_v = is_reference<T>::value;
template<class T> struct remove_const { typedef T type; };
template<class T> struct remove_const<const T> { typedef T type; };
template< class T >
using remove_const_t = typename remove_const<T>::type;
template< class T >
inline constexpr bool is_class_v = is_class<T>::value;
template< class T >
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
template< class... >
using void_t = void;
using __hip::declval;
#else
#include <utility>
#include <type_traits>
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_pointer;
using std::is_reference;
using std::is_trivially_copyable;
using std::is_unsigned;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::is_const_v;
using std::is_reference_v;
using std::remove_const_t;
using std::is_class_v;
using std::is_trivially_copyable_v;
using std::void_t;
using std::false_type;
using std::true_type;
using std::declval;
#endif
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
@@ -23,19 +175,19 @@ template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value;
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;
using remove_reference_t = typename remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
using remove_cv_t = typename remove_cv<T>::type;
template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;
using remove_pointer_t = typename remove_pointer<T>::type;
template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
inline constexpr bool is_pointer_v = is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x)

View File

@@ -17,10 +17,10 @@ namespace ck {
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x);
}
@@ -28,13 +28,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>;
using NonConstY = ck::remove_const_t<Y>;
using NonConstX = ck::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
@@ -104,7 +104,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x);
}
@@ -166,7 +166,7 @@ template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
#if defined(__gfx94__)
union
{
@@ -206,7 +206,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
@@ -218,7 +218,7 @@ template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
#if defined(__gfx94__)
union
{
@@ -258,7 +258,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
@@ -501,6 +501,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif
}
#ifndef __HIPCC_RTC__
template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x)
@@ -510,6 +511,7 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
y[i] = type_convert<Y>(x[i]);
}
}
#endif
template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)