diff --git a/CMakeLists.txt b/CMakeLists.txt
index fd321f7722..71fdb91d8e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -128,6 +128,8 @@ list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/ll
message("GPU_TARGETS= ${GPU_TARGETS}")
+option(CK_BUILD_HOST_LIB, "Only build the CK JIT Helper Library" OFF)
+
find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
@@ -254,6 +256,7 @@ elseif(CK_PARALLEL_COMPILE_JOBS)
message(WARNING "Job pooling is only available with Ninja generators.")
endif()
+if (NOT CK_BUILD_HOST_LIB)
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
@@ -275,6 +278,8 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
link_libraries(Threads::Threads)
+endif() # NOT CK_BUILD_HOST_LIB
+
## C++
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -291,6 +296,8 @@ if(USE_GLIBCXX_ASSERTIONS)
add_compile_options(-Wp,-D_GLIBCXX_ASSERTIONS)
endif()
+if (NOT CK_BUILD_HOST_LIB)
+
## HIP
set(CMAKE_HIP_PLATFORM amd)
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
@@ -346,6 +353,8 @@ else()
add_compile_definitions(__HIP_PLATFORM_HCC__=1)
endif()
+endif() # NOT CK_BUILD_HOST_LIB
+
## tidy
include(EnableCompilerWarnings)
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
@@ -499,6 +508,8 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS}
)
+if (NOT CK_BUILD_HOST_LIB)
+
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
@@ -506,6 +517,8 @@ if(BUILD_DEV)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
+endif() # NOT CK_BUILD_HOST_LIB
+
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
add_compile_options(-fcolor-diagnostics)
endif()
@@ -515,6 +528,8 @@ endif()
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
+if (NOT CK_BUILD_HOST_LIB)
+
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
set(CK_DEVICE_INSTANCES)
@@ -590,6 +605,18 @@ if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCE
add_subdirectory(codegen)
endif()
+else() # NOT CK_BUILD_HOST_LIB
+
+if(GPU_TARGETS MATCHES "gfx9")
+ rocm_package_setup_component(ck_host
+ LIBRARY_NAME composablekernel
+ PACKAGE_NAME ck_host
+ )
+ add_subdirectory(codegen)
+endif()
+
+endif() # NOT CK_BUILD_HOST_LIB
+
#Create an interface target for the include only files and call it "composablekernels"
include(CMakePackageConfigHelpers)
@@ -627,4 +654,4 @@ rocm_create_package(
MAINTAINER "MIOpen Kernels Dev Team
"
LDCONFIG
HEADER_ONLY
-)
+)
\ No newline at end of file
diff --git a/Config.cmake.in b/Config.cmake.in
index 2861a28f49..a260bc9e6e 100644
--- a/Config.cmake.in
+++ b/Config.cmake.in
@@ -1,6 +1,6 @@
@PACKAGE_INIT@
-set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility)
+set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility ck_host)
foreach(_comp ${composable_kernel_FIND_COMPONENTS})
if(NOT _comp IN_LIST _composable_kernel_supported_components)
diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt
index 3b3e9f06ee..4bf065d9df 100644
--- a/codegen/CMakeLists.txt
+++ b/codegen/CMakeLists.txt
@@ -31,12 +31,21 @@ file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
##message(STATUS "SOURCE_FILES: ${SOURCES}")
# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
-target_link_libraries(ck_host PRIVATE ck_headers)
+add_library(composable_kernel::ck_host ALIAS ck_host)
set_target_properties(ck_host PROPERTIES
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE ON)
+target_include_directories(ck_host SYSTEM PRIVATE
+ $
+ # $
+ $
+ $
+)
+
+target_link_libraries(ck_host PRIVATE $)
+
target_include_directories(ck_host PUBLIC
$
)
@@ -45,9 +54,18 @@ add_executable(ck-template-driver driver/main.cpp)
target_link_libraries(ck-template-driver ck_host)
rocm_install(
- TARGETS ck_host ck_headers
+ TARGETS ck_host
EXPORT ck_hostTargets
)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
-add_subdirectory(test)
+rocm_install(
+ EXPORT ck_hostTargets
+ FILE composable_kernelck_hostTargets.cmake
+ NAMESPACE composable_kernel::
+ DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
+)
+
+if (NOT CK_BUILD_HOST_LIB)
+ add_subdirectory(test)
+endif()
diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
new file mode 100644
index 0000000000..d992b04536
--- /dev/null
+++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp
@@ -0,0 +1,58 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+#include "ck/host/types.hpp"
+#include "ck/host/operation/gemm.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
+
+namespace ck {
+namespace host {
+namespace device_batched_gemm_softmax_gemm {
+
+// defines all values need for an instance of fwd conv
+struct Operation_Xdl_CShuffle
+{
+ // returns a vector of instances, only given fusion operators: will use default problem spec
+ static std::vector>
+ CreateOperations(const std::string& prologue, const std::string& epilogue);
+ // returns a vector of instances, given a problem spec and fusion operators
+ static std::vector
+ CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
+ TensorDesc A{};
+ TensorDesc B{};
+ TensorDesc B1{};
+ TensorDesc C{};
+ std::string a_elem_op = PassThrough;
+ std::string b_elem_op = PassThrough;
+ std::string b1_elem_op = PassThrough;
+ std::string c_elem_op = PassThrough;
+ std::string acc_elem_op = Scale;
+ std::string prologue = "";
+ std::string epilogue = "";
+ std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
+ // tuning parameters
+ operation::TileDescGemmSoftmaxGemm tile_desc{};
+ operation::BlockTransferDesc a_block_transfer{};
+ operation::BlockTransferDesc b0_block_transfer{};
+ operation::BlockTransferDesc b1_block_transfer{};
+ operation::CShuffleDesc cshuffle{};
+ operation::CBlockTransferDesc c_block_transfer{};
+
+ bool mask_out_upper_triangle = false;
+
+ // functions to update fusion operators if provided
+ void update_prologue(const std::string& prologue);
+ void update_epilogue(const std::string& epilogue);
+ /**constexpr**/ bool IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_);
+ // returns a templated instance
+ Solution ToSolution() const;
+};
+
+} // namespace device_batched_gemm_softmax_gemm
+} // namespace host
+} // namespace ck
diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
new file mode 100644
index 0000000000..428034a3ba
--- /dev/null
+++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp
@@ -0,0 +1,47 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+#include "ck/host/types.hpp"
+
+namespace ck {
+namespace host {
+namespace device_batched_gemm_softmax_gemm {
+
+// defines the problem specification for a GEMM operation
+struct Problem
+{
+ std::size_t M = 0;
+ std::size_t N = 0;
+ std::size_t K = 0;
+ std::size_t O = 0;
+ bool TransA = false;
+ bool TransB = false;
+ bool TransB1 = false;
+ bool TransC = false;
+ DataType ADataType = DataType::Half;
+ DataType BDataType = DataType::Half;
+ DataType B1DataType = DataType::Half;
+ DataType CDataType = DataType::Half;
+ std::string AElementOp = PassThrough;
+ std::string BElementOp = PassThrough;
+ std::string B1ElementOp = PassThrough;
+ std::string CElementOp = PassThrough;
+ std::string AccElementOp = Scale;
+
+ // returns the correct device op file for the operation
+ std::string GetIncludeHeader() const;
+
+ // returns a list of instances based on the problem spec and provided fusion operations
+ std::vector GetSolutions(const std::string& arch,
+ const std::string& prologue,
+ const std::string& epilogue) const;
+};
+
+} // namespace device_batched_gemm_softmax_gemm
+} // namespace host
+} // namespace ck
diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
index 359da7d8cf..e5eeb6be15 100644
--- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
+++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
operation::BlockTransferDesc b_block_transfer{};
operation::CShuffleDesc cshuffle{};
operation::CBlockTransferDesc c_block_transfer{};
+ LoopScheduler loop_scheduler{};
+ PipelineVersion pipeline_version{};
// functions to update fusion operators if provided
void update_prologue(const std::string& prologue);
diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp
index 84ef92f0a0..790c51e773 100644
--- a/codegen/include/ck/host/operation/gemm.hpp
+++ b/codegen/include/ck/host/operation/gemm.hpp
@@ -23,6 +23,26 @@ struct TileDesc
int n_Xdl_per_wave = 0;
int num_gemmk_prefetch_stage = 0;
};
+
+struct TileDescGemmSoftmaxGemm
+{
+ int block_size = 0;
+ int gemm01_m_per_block = 0;
+ int gemm0_n_per_block = 0;
+ int gemm0_k_per_block = 0;
+ int gemm1_n_per_block = 0;
+ int gemm1_k_per_block = 0;
+ int ak1 = 0;
+ int bk1 = 0;
+ int b1k1 = 0;
+ int m_per_XDL = 0;
+ int n_per_XDL = 0;
+ int gemm0_m_Xdl_per_wave = 0;
+ int gemm0_n_Xdl_per_wave = 0;
+ int gemm1_n_Xdl_per_wave = 0;
+ int num_gemmk_prefetch_stage = 0;
+};
+
struct BlockTransferDesc
{
std::string thread_cluster_length = "";
diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp
index 8bad7bf89c..b05e134176 100644
--- a/codegen/include/ck/host/types.hpp
+++ b/codegen/include/ck/host/types.hpp
@@ -66,6 +66,20 @@ enum class GemmType
};
std::string ToString(GemmType gt);
+enum class LoopScheduler
+{
+ Default,
+ Interwave,
+};
+std::string ToString(LoopScheduler ls);
+
+enum class PipelineVersion
+{
+ v1,
+ v2
+};
+std::string ToString(PipelineVersion pv);
+
struct TensorDesc
{
DataType element;
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
+constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
} // namespace host
} // namespace ck
diff --git a/codegen/src/device_batched_gemm_softmax_gemm.cpp b/codegen/src/device_batched_gemm_softmax_gemm.cpp
new file mode 100644
index 0000000000..cf140ead1d
--- /dev/null
+++ b/codegen/src/device_batched_gemm_softmax_gemm.cpp
@@ -0,0 +1,38 @@
+
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
+#include "ck/host/utils.hpp"
+#include
+
+namespace ck {
+namespace host {
+namespace device_batched_gemm_softmax_gemm {
+
+// return the relevant device op file based on the operation
+std::string Problem::GetIncludeHeader() const
+{
+ return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
+}
+
+// returns templated instances when provided with a problem specification
+std::vector Problem::GetSolutions(const std::string& arch,
+ const std::string& prologue,
+ const std::string& epilogue) const
+{
+ if(get_xdlop_archs().count(arch) == 0)
+ return {};
+ auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
+ *this, prologue, epilogue); // obtains vector of instances
+ std::vector result;
+ std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
+ return op.ToSolution(); // template instance with correct values
+ });
+ return result;
+}
+
+} // namespace device_batched_gemm_softmax_gemm
+} // namespace host
+} // namespace ck
diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
new file mode 100644
index 0000000000..aa68dbe337
--- /dev/null
+++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp
@@ -0,0 +1,412 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
+#include "ck/host/stringutils.hpp"
+#include "ck/host/utils.hpp"
+#include
+
+namespace ck {
+namespace host {
+namespace device_batched_gemm_softmax_gemm {
+
+// calculate appropriate Gemm Specification based on input tensor dimensions
+std::string GetGemmSpec(const std::size_t m,
+ const std::size_t n,
+ const std::size_t k,
+ const std::size_t n1,
+ const std::size_t m_per_block,
+ const std::size_t n_per_block,
+ const std::size_t k_per_block,
+ const std::size_t n1_per_block)
+{
+ std::string spec = "";
+ if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
+ spec += "M";
+ if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
+ spec += "N";
+ if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
+ spec += "K";
+ if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
+ spec += "O";
+ if(spec == "")
+ return "ck::tensor_operation::device::GemmSpecialization::Default";
+
+ return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
+}
+
+// function to update prologue/epilogue with user provided operation
+void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
+{
+ if(!prologue.empty())
+ {
+ this->prologue = pro;
+ // TODO
+ // this->cde_elem_op = "CDEElementOp";
+ }
+ else
+ {
+ this->prologue = "";
+ }
+}
+
+void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
+{
+ if(!epilogue.empty())
+ {
+ this->epilogue = epi;
+ // TODO
+ // this->cde_elem_op = "CDEElementOp";
+ }
+ else
+ {
+ this->epilogue = "";
+ }
+}
+
+// accounts for all possible combinations of Row/Col major
+static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
+
+// Hard-code tuning parameters in modularized fashion, string them together into a vector of
+// instances
+std::vector Operation_Xdl_CShuffle::CreateOperations(
+ const Problem& prob, const std::string& prologue, const std::string& epilogue)
+{
+ std::vector result;
+
+ std::vector tile_descriptions = {
+ // clang-format off
+// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK|
+// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
+// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
+// | | | | | | | | | | | Wave| Wave| Wave| |
+ { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
+ { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
+ { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
+ { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
+ { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
+ { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
+ { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
+ { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
+ { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
+ { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
+ { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
+ { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
+// Padded fallback kernel
+ { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
+ { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
+// Irregular k
+ { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
+ { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
+ { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
+ { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
+ { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
+ { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
+ // clang-format on
+ };
+
+ const std::vector a_block_descriptions = {
+ // clang-format off
+// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
+// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
+// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
+// | | | | | | |
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+ { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+// Padded fallback kernel
+ { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
+ { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
+// Irregular k
+ { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
+ { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
+ { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
+ { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
+ { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
+ { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
+ // clang-format on
+ };
+
+ const std::vector b1_block_descriptions = {
+ // clang-format off
+// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
+// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
+// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
+// | | | | | | |
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+// Padded fallback kernel
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+// Irregular k
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
+ // clang-format on
+ };
+
+ std::vector cshuffle_descriptions = {
+ // clang-format off
+// CShuffle| CShuffle|
+// MXdlPerWave| NXdlPerWave|
+// PerShuffle| PerShuffle|
+// | |
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 8},
+ { 1, 4},
+ { 1, 8},
+ { 1, 4},
+// Padded fallback kernel
+ { 1, 2},
+ { 1, 2},
+// Irregular k
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ { 1, 2},
+ // clang-format on
+ };
+
+ std::vector c_block_descriptions = {
+ // clang-format off
+// CBlockTransferClusterLengths| CBlockTransfer
+// _MBlock_MWaveMPerXdl| ScalarPerVector
+// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
+// |
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 16, 1,16>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 16, 1,16>, 8},
+ { S<1, 32, 1, 8>, 8},
+// Padded fallback kernel
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+// Irregular k
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ { S<1, 32, 1, 8>, 8},
+ // clang-format on
+ };
+
+ assert(tile_descriptions.size() == a_block_descriptions.size());
+ assert(tile_descriptions.size() == b1_block_descriptions.size());
+ assert(tile_descriptions.size() == cshuffle_descriptions.size());
+ assert(tile_descriptions.size() == c_block_descriptions.size());
+
+ // Put all values together into a single operation > store into the result vector
+ for(std::size_t i = 0; i < tile_descriptions.size(); i++)
+ {
+ Operation_Xdl_CShuffle x;
+ x.tile_desc = tile_descriptions[i];
+ x.a_block_transfer = a_block_descriptions[i];
+ x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a
+ x.b1_block_transfer = b1_block_descriptions[i];
+ x.cshuffle = cshuffle_descriptions[i];
+ x.c_block_transfer = c_block_descriptions[i];
+ x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
+ x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
+ x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)};
+ x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)};
+ x.a_elem_op = prob.AElementOp;
+ x.b_elem_op = prob.BElementOp;
+ x.b1_elem_op = prob.B1ElementOp;
+ x.c_elem_op = prob.CElementOp;
+ x.acc_elem_op = prob.AccElementOp;
+ x.gemm_specialization = GetGemmSpec(prob.M,
+ prob.N,
+ prob.K,
+ prob.O,
+ x.tile_desc.gemm01_m_per_block,
+ x.tile_desc.gemm0_n_per_block,
+ x.tile_desc.gemm0_k_per_block,
+ x.tile_desc.gemm1_n_per_block);
+ x.update_prologue(prologue);
+ x.update_epilogue(epilogue);
+ x.mask_out_upper_triangle = true;
+ result.push_back(x);
+
+ x.mask_out_upper_triangle = false;
+ result.push_back(x);
+ }
+ return result;
+}
+
+// set up instances when not provided with a problem specification, use default operation values and
+// all possible layout combinations
+std::vector>
+Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
+{
+ Problem prob;
+ prob.TransA = false;
+ prob.TransB = true;
+ prob.TransB1 = false;
+ prob.TransC = false;
+
+ return {CreateOperations(prob, prologue, epilogue)};
+}
+
+static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate =
+ "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, "
+ "${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, "
+ "${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
+ "${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, "
+ "${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
+ "${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
+ "${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, "
+ "${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
+ "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
+ "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
+ "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
+ "${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
+ "${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
+ "${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
+ "${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
+ "${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
+ "${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
+ "${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
+ "${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
+ "${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
+ "${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, "
+ "${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>";
+
+// use hardcoded instances from vector of operations to substitute values into instance template
+Solution Operation_Xdl_CShuffle::ToSolution() const
+{
+ std::unordered_map values = {
+ {"name",
+ std::to_string(this->tile_desc.block_size) + "_" +
+ std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
+ std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
+ std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
+ std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
+ std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
+ std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
+ std::to_string(this->tile_desc.b1k1) + "_" +
+ std::to_string(this->tile_desc.m_per_XDL) + "_" +
+ std::to_string(this->tile_desc.n_per_XDL) + "_" +
+ std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
+ std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" +
+ std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
+ {"LayoutA", ToString(this->A.layout)},
+ {"LayoutB0", ToString(this->B.layout)},
+ {"LayoutB1", ToString(this->B1.layout)},
+ {"LayoutC", ToString(this->C.layout)},
+ {"ADataType", ToString(this->A.element)},
+ {"B0DataType", ToString(this->B.element)},
+ {"B1DataType", ToString(this->B1.element)},
+ {"CDataType", ToString(this->C.element)},
+ {"AccDataType", ToString(DataType::Float)},
+ {"CShuffleDataType", ToString(DataType::Half)},
+ {"AElementwiseOperation", this->a_elem_op},
+ {"B0ElementwiseOperation", this->b_elem_op},
+ {"Acc0ElementwiseOperation", this->acc_elem_op},
+ {"B1ElementwiseOperation", this->b1_elem_op},
+ {"CElementwiseOperation", this->c_elem_op},
+ {"GemmSpecialization", this->gemm_specialization},
+ {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
+ {"BlockSize", std::to_string(this->tile_desc.block_size)},
+ {"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
+ {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
+ {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
+ {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
+ {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
+ {"AK1", std::to_string(this->tile_desc.ak1)},
+ {"BK1", std::to_string(this->tile_desc.bk1)},
+ {"B1K1", std::to_string(this->tile_desc.b1k1)},
+ {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
+ {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
+ {"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)},
+ {"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)},
+ {"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
+ {"ABlockTransferThreadClusterLengths_AK0_M_AK1",
+ this->a_block_transfer.thread_cluster_length},
+ {"ABlockTransferThreadClusterArrangeOrder",
+ this->a_block_transfer.thread_cluster_arrange_order},
+ {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
+ {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
+ {"ABlockTransferSrcScalarPerVector",
+ std::to_string(this->a_block_transfer.src_scalar_per_vector)},
+ {"ABlockTransferDstScalarPerVector_AK1",
+ std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
+ {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
+ {"B0BlockTransferThreadClusterLengths_BK0_N_BK1",
+ this->b0_block_transfer.thread_cluster_length},
+ {"B0BlockTransferThreadClusterArrangeOrder",
+ this->b0_block_transfer.thread_cluster_arrange_order},
+ {"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order},
+ {"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)},
+ {"B0BlockTransferSrcScalarPerVector",
+ std::to_string(this->b0_block_transfer.src_scalar_per_vector)},
+ {"B0BlockTransferDstScalarPerVector_BK1",
+ std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)},
+ {"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)},
+ {"B1BlockTransferThreadClusterLengths_BK0_N_BK1",
+ this->b1_block_transfer.thread_cluster_length},
+ {"B1BlockTransferThreadClusterArrangeOrder",
+ this->b1_block_transfer.thread_cluster_arrange_order},
+ {"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order},
+ {"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)},
+ {"B1BlockTransferSrcScalarPerVector",
+ std::to_string(this->b1_block_transfer.src_scalar_per_vector)},
+ {"B1BlockTransferDstScalarPerVector_BK1",
+ std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)},
+ {"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)},
+ {"CShuffleMXdlPerWavePerShuffle",
+ std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
+ {"CShuffleNXdlPerWavePerShuffle",
+ std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
+ {"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl",
+ this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
+ {"CBlockTransferScalarPerVector_NWaveNPerXdl",
+ std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
+ {"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)},
+ };
+
+ return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values),
+ std::move(values)};
+}
+
+} // namespace device_batched_gemm_softmax_gemm
+} // namespace host
+} // namespace ck
diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
index fff75c1962..f4b61ee99a 100644
--- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
+++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp
@@ -62,6 +62,13 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
// accounts for all possible combinations of Row/Col major
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
+
+
+// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
+
+// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
+
+
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
// instances
std::vector Operation_Xdl_CShuffle::CreateOperations(
@@ -83,6 +90,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
+// Irregular tile
+ { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
// clang-format on
};
@@ -100,6 +109,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
+// Irregular tile
+ { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on
};
@@ -109,15 +120,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
// | | | | | | |
+ { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+ { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+ { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+ { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+ { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
+ { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+ { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+ { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
+// Irregular tile
+ { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on
- {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
- {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
- {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
- {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
- {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
- {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
- {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
- {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
};
std::vector b_block_descriptions_rowmajor = {
@@ -134,6 +147,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
+// Irregular tile
+ { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
// clang-format on
};
@@ -151,6 +166,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
+// Irregular tile
+ { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
// clang-format on
};
@@ -167,6 +184,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
{ 1, 1},
{ 1, 1},
{ 1, 1},
+ { 1, 1},
{ 1, 1},
// clang-format on
};
@@ -185,6 +203,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
{ S<1, 16, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
{ S<1, 32, 1, 8>, 8},
+// Irregular tile
+ { S<1, 16, 1, 4>, 1},
// clang-format on
};
@@ -199,33 +219,44 @@ std::vector Operation_Xdl_CShuffle::CreateOperations(
assert(tile_descriptions.size() == cshuffle_descriptions.size());
assert(tile_descriptions.size() == c_block_descriptions.size());
- // Put all values together into a single operation > store into the result vector
- for(std::size_t i = 0; i < tile_descriptions.size(); i++)
+ const std::vector> scheduler_pipeline_descriptions =
+ {
+ {LoopScheduler::Default, PipelineVersion::v1},
+ {LoopScheduler::Interwave, PipelineVersion::v1},
+ {LoopScheduler::Default, PipelineVersion::v2},
+ };
+ for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
{
- Operation_Xdl_CShuffle x;
- x.tile_desc = tile_descriptions[i];
- x.a_block_transfer = a_block_descriptions[i];
- x.b_block_transfer = b_block_descriptions[i];
- x.cshuffle = cshuffle_descriptions[i];
- x.c_block_transfer = c_block_descriptions[i];
- x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
- x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
- x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
- x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
- return TensorDesc{dt, ToLayout(trans)};
- });
- x.a_elem_op = prob.AElementOp;
- x.b_elem_op = prob.BElementOp;
- x.cde_elem_op = prob.CDEElementOp;
- x.gemm_specialization = GetGemmSpec(prob.M,
- prob.N,
- prob.K,
- x.tile_desc.m_per_block,
- x.tile_desc.n_per_block,
- x.tile_desc.k_per_block);
- x.update_prologue(prologue);
- x.update_epilogue(epilogue);
- result.push_back(x);
+ // Put all values together into a single operation > store into the result vector
+ for(std::size_t i = 0; i < tile_descriptions.size(); i++)
+ {
+ Operation_Xdl_CShuffle x;
+ x.tile_desc = tile_descriptions[i];
+ x.a_block_transfer = a_block_descriptions[i];
+ x.b_block_transfer = b_block_descriptions[i];
+ x.cshuffle = cshuffle_descriptions[i];
+ x.c_block_transfer = c_block_descriptions[i];
+ x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
+ x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
+ x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
+ x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
+ return TensorDesc{dt, ToLayout(trans)};
+ });
+ x.a_elem_op = prob.AElementOp;
+ x.b_elem_op = prob.BElementOp;
+ x.cde_elem_op = prob.CDEElementOp;
+ x.gemm_specialization = GetGemmSpec(prob.M,
+ prob.N,
+ prob.K,
+ x.tile_desc.m_per_block,
+ x.tile_desc.n_per_block,
+ x.tile_desc.k_per_block);
+ x.loop_scheduler = loop_scheduler;
+ x.pipeline_version = pipeline_version;
+ x.update_prologue(prologue);
+ x.update_epilogue(epilogue);
+ result.push_back(x);
+ }
}
return result;
}
@@ -263,7 +294,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
- "${CDEBlockTransferScalarPerVector_NPerBlock}>";
+ "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
// use hardcoded instances from vector of operations to substitute values into instance template
Solution Operation_Xdl_CShuffle::ToSolution() const
@@ -336,6 +367,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
{"CDEBlockTransferScalarPerVector_NPerBlock",
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
+ {"LoopScheduler", ToString(this->loop_scheduler)},
+ {"PipelineVersion", ToString(this->pipeline_version)},
};
return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp
index a8a8b10c04..4757cab536 100644
--- a/codegen/src/types.cpp
+++ b/codegen/src/types.cpp
@@ -56,6 +56,26 @@ std::string ToString(GemmType gt)
throw std::runtime_error("Incorrect gemm type");
}
+std::string ToString(LoopScheduler ls)
+{
+ switch(ls)
+ {
+ case LoopScheduler::Default: return "ck::LoopScheduler::Default";
+ case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
+ }
+ throw std::runtime_error("Incorrect LoopScheduler type");
+}
+
+std::string ToString(PipelineVersion pv)
+{
+ switch(pv)
+ {
+ case PipelineVersion::v1: return "ck::PipelineVersion::v1";
+ case PipelineVersion::v2: return "ck::PipelineVersion::v2";
+ }
+ throw std::runtime_error("Incorrect PipelineVersion type");
+}
+
std::string SequenceStr(const std::vector& v)
{
return "ck::Sequence<" +
diff --git a/codegen/test/common.hpp b/codegen/test/common.hpp
index 99d4c64973..48afb7e042 100644
--- a/codegen/test/common.hpp
+++ b/codegen/test/common.hpp
@@ -15,7 +15,8 @@ std::vector get_headers_for_test()
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
- return {p.first, p.second};
+ std::string sec(p.second.begin(), p.second.end());
+ return {p.first, sec};
});
return result;
}
diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp
index bd7ef463fb..7874caacac 100644
--- a/codegen/test/gemm_multiple_d.cpp
+++ b/codegen/test/gemm_multiple_d.cpp
@@ -1,5 +1,7 @@
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
@@ -15,13 +17,59 @@
using half = _Float16;
// using half = __fp16;
+// NOLINTNEXTLINE
+const char* const disable_warning_pragma = R"__migraphx__(
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Weverything"
+${content}
+#pragma clang diagnostic pop
+)__migraphx__";
+
+template
+std::string ck_disable_warnings(P p)
+{
+ return ck::host::InterpolateString(disable_warning_pragma,
+ {{"content", std::string{p.data(), p.size()}}});
+}
+
+static std::unordered_map create_ck_header_strings()
+{
+ std::unordered_map result;
+ auto ck_headers = ck::host::GetHeaders();
+
+ std::transform(
+ ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
+ return std::pair(p.first, ck_disable_warnings(p.second));
+ });
+ return result;
+}
+
+static std::vector create_ck_headers()
+{
+ static const auto& header_strings = create_ck_header_strings();
+ std::vector srcs;
+ std::transform(
+ header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto& p) -> rtc::src_file {
+ std::string sec(p.second.begin(), p.second.end());
+ return {p.first, sec};
+ });
+ return srcs;
+}
+
+static inline const std::vector& ck_headers()
+{
+ static const auto& headers = create_ck_headers();
+ return headers;
+}
+
std::vector get_headers_for_test()
{
std::vector result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
- return {p.first, p.second};
+ std::string sec(p.second.begin(), p.second.end());
+ return {p.first, sec};
});
return result;
}
@@ -130,10 +178,13 @@ const std::string gemm_compile_check = R"__ck__(
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
using G = ${template};
- constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
- ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})),
- ck::make_tuple(),
- ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));
+ constexpr auto desc =
+ G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
+ ${k})),
+ ck::make_naive_tensor_descriptor(ck::make_tuple(${n},
+ ${k}), ck::make_tuple(1, ${n})), ck::make_tuple(),
+ ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m},
+ ${n})));
static_assert(desc.IsValid(), "Invalid ck gemm.");
@@ -163,23 +214,32 @@ TEST_CASE(test_problem_kernel)
std::string epilogue = "";
std::string prologue = "";
- for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
+ auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
+ std::cout << "Num solutions: " << solutions.size() << std::endl;
+ for(auto i = 0; i < solutions.size(); ++i)
{
- auto src = ck::host::InterpolateString(gemm_compile_check,
- {{"include", prob.GetIncludeHeader()},
- {"template", solution.ToTemplateString()},
- {"m", std::to_string(prob.M)},
- {"n", std::to_string(prob.N)},
- {"k", std::to_string(prob.K)}});
- auto srcs = get_headers_for_test();
- srcs.push_back({"main.cpp", src});
- rtc::compile_options options;
+ std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
+ auto&& solution = solutions[i];
+ auto src = ck::host::InterpolateString(gemm_compile_check,
+ {{"include", prob.GetIncludeHeader()},
+ {"template", solution.ToTemplateString()},
+ {"m", std::to_string(prob.M)},
+ {"n", std::to_string(prob.N)},
+ {"k", std::to_string(prob.K)}});
+ // auto srcs = get_headers_for_test();
+ // srcs.push_back({"main.cpp", src});
+ // rtc::compile_options options;
+ // options.kernel_name = "f";
+ rtc::hip_compile_options options;
options.kernel_name = "f";
- auto k = rtc::compile_kernel(srcs, options);
- auto block_size = solution.GetTemplateParameter("BlockSize");
- auto m_per_block = solution.GetTemplateParameter("MPerBlock");
- auto n_per_block = solution.GetTemplateParameter("NPerBlock");
- auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
+ options.additional_src_files = ck_headers();
+ // auto k = rtc::compile_kernel(srcs, options);
+ std::cout << src << std::endl;
+ auto k = rtc::compile_hip_code_object(src, options);
+ auto block_size = solution.GetTemplateParameter("BlockSize");
+ auto m_per_block = solution.GetTemplateParameter("MPerBlock");
+ auto n_per_block = solution.GetTemplateParameter("NPerBlock");
+ auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
ck::host::integer_divide_ceil(prob.N, n_per_block);
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
@@ -187,4 +247,34 @@ TEST_CASE(test_problem_kernel)
}
}
+TEST_CASE(test_gemm_softmax_gemm)
+{
+ ck::host::device_batched_gemm_softmax_gemm::Problem prob;
+ prob.TransA = false;
+ prob.TransB = true;
+ prob.TransB1 = false;
+ prob.TransC = false;
+ prob.M = 1024;
+ prob.N = 1024;
+ prob.K = 1024;
+ prob.O = 1024;
+ check_all check;
+ auto a = to_gpu(generate_buffer(1024 * 1024, 0));
+ auto b = to_gpu(generate_buffer(1024 * 1024, 1));
+ auto b1 = to_gpu(generate_buffer(1024 * 1024, 2));
+ auto c = to_gpu(generate_buffer(1024 * 1024, 3));
+
+ std::string epilogue = "";
+ std::string prologue = "";
+
+ auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
+ std::cout << "Num solutions: " << solutions.size() << std::endl;
+
+ for(auto i = 0; i < solutions.size(); ++i) {
+ std::cout << "Solution " << i << std::endl;
+ std::cout << solutions[i].ToTemplateString() << std::endl;
+ std::cout << std::endl;
+ }
+}
+
int main(int argc, const char* argv[]) { test::run(argc, argv); }
diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp
index 71db7be249..e832c42d4d 100644
--- a/codegen/test/rtc/include/rtc/compile_kernel.hpp
+++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp
@@ -4,6 +4,7 @@
#include