diff --git a/CMakeLists.txt b/CMakeLists.txt index 1fe1bc91d5..e90f893de0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -92,6 +92,7 @@ endif() add_compile_options(-Wno-bit-int-extension) add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) +add_compile_options(-Wno-unique-object-duplication) if(DL_KERNELS) add_definitions(-DDL_KERNELS) diff --git a/Jenkinsfile b/Jenkinsfile index 835b7e724f..80392bfbed 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){ { echo "Pulling down image: ${image}" retimage = docker.image("${image}") - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() } } @@ -148,7 +148,7 @@ def buildDocker(install_prefix){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs) - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' @@ -162,7 +162,7 @@ def buildDocker(install_prefix){ catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" retimage = docker.build("${image_name}", dockerArgs + ' .') - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } } diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp new file mode 100644 index 0000000000..301df0a529 --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines all values needed for an instance of the batched GEMM + softmax + GEMM device op +struct Operation_Xdl_CShuffle +{ + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); + TensorDesc A{}; + TensorDesc B{}; + TensorDesc B1{}; + TensorDesc C{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string b1_elem_op = PassThrough; + std::string c_elem_op = PassThrough; + std::string acc_elem_op = Scale; + std::string prologue = ""; + std::string epilogue = ""; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters + operation::TileDescGemmGemm tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b0_block_transfer{}; + operation::BlockTransferDesc b1_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + bool mask_out_upper_triangle = false; + + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool + IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // 
namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp new file mode 100644 index 0000000000..428034a3ba --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines the problem specification for a batched GEMM + softmax + GEMM operation +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + std::size_t O = 0; + bool TransA = false; + bool TransB = false; + bool TransB1 = false; + bool TransC = false; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType B1DataType = DataType::Half; + DataType CDataType = DataType::Half; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string B1ElementOp = PassThrough; + std::string CElementOp = PassThrough; + std::string AccElementOp = Scale; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index 359da7d8cf..e5eeb6be15 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle 
operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + LoopScheduler loop_scheduler{}; + PipelineVersion pipeline_version{}; // functions to update fusion operators if provided void update_prologue(const std::string& prologue); diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index 84ef92f0a0..5a51a0002e 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -23,6 +23,26 @@ struct TileDesc int n_Xdl_per_wave = 0; int num_gemmk_prefetch_stage = 0; }; + +struct TileDescGemmGemm +{ + int block_size = 0; + int gemm01_m_per_block = 0; + int gemm0_n_per_block = 0; + int gemm0_k_per_block = 0; + int gemm1_n_per_block = 0; + int gemm1_k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int b1k1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int gemm0_m_Xdl_per_wave = 0; + int gemm0_n_Xdl_per_wave = 0; + int gemm1_n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; + struct BlockTransferDesc { std::string thread_cluster_length = ""; diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 8bad7bf89c..b05e134176 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -66,6 +66,20 @@ enum class GemmType }; std::string ToString(GemmType gt); +enum class LoopScheduler +{ + Default, + Interwave, +}; +std::string ToString(LoopScheduler ls); + +enum class PipelineVersion +{ + v1, + v2 +}; +std::string ToString(PipelineVersion pv); + struct TensorDesc { DataType element; @@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...}); constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; +constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale"; } // namespace host } // namespace ck diff --git 
a/codegen/src/device_batched_gemm_softmax_gemm.cpp b/codegen/src/device_batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000..cf140ead1d --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm.cpp @@ -0,0 +1,38 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// return the relevant device op file based on the operation +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"; +} + +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); // template instance with correct values + }); + return result; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000..b12c2e1a4a --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// calculate appropriate Gemm Specification based on input tensor dimensions +std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t n1, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block, + const std::size_t n1_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0) + spec += "O"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// functions to update prologue/epilogue with the user provided operation: test the argument (not the member, which starts out empty); an empty argument clears the stored operator +void Operation_Xdl_CShuffle::update_prologue(const std::string& pro) +{ + if(!pro.empty()) + { + this->prologue = pro; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) +{ + if(!epi.empty()) + { + this->epilogue = epi; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major +static Layout ToLayout(bool Trans) { return Trans ? 
Layout::Column : Layout::Row; } + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK| +// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| +// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| +// | | | | | | | | | | | Wave| Wave| Wave| | + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, + { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, +// Padded fallback kernel + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, +// Irregular k + { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1}, + // clang-format on + }; + + const std::vector 
a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Padded fallback kernel + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Irregular k + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + // clang-format on + }; + + const std::vector b1_block_descriptions = { + // clang-format off +// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Padded fallback kernel + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Irregular k + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 8}, + { 1, 4}, + { 1, 8}, + { 1, 4}, +// Padded fallback kernel + { 1, 2}, + { 1, 2}, +// Irregular k + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 
1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, +// Padded fallback kernel + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, +// Irregular k + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b1_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a + x.b1_block_transfer = b1_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)}; + x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)}; + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.b1_elem_op = prob.B1ElementOp; + x.c_elem_op = prob.CElementOp; + x.acc_elem_op = prob.AccElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + prob.O, + x.tile_desc.gemm01_m_per_block, + x.tile_desc.gemm0_n_per_block, + x.tile_desc.gemm0_k_per_block, + x.tile_desc.gemm1_n_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); + x.mask_out_upper_triangle = true; + result.push_back(x); + + x.mask_out_upper_triangle = false; + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem 
specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) +{ + Problem prob; + prob.TransA = false; + prob.TransB = true; + prob.TransB1 = false; + prob.TransC = false; + + return {CreateOperations(prob, prologue, epilogue)}; +} + +static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, " + "${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, " + "${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, " + "${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, " + "${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, " + "${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, " + "${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, " + "${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, " + "${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, " + "${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, " + "${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, " + "${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, " + "${CShuffleMXdlPerWavePerShuffle}, 
${CShuffleNXdlPerWavePerShuffle}, " + "${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, " + "${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.gemm01_m_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.b1k1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB0", ToString(this->B.layout)}, + {"LayoutB1", ToString(this->B1.layout)}, + {"LayoutC", ToString(this->C.layout)}, + {"ADataType", ToString(this->A.element)}, + {"B0DataType", ToString(this->B.element)}, + {"B1DataType", ToString(this->B1.element)}, + {"CDataType", ToString(this->C.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"AElementwiseOperation", this->a_elem_op}, + {"B0ElementwiseOperation", this->b_elem_op}, + {"Acc0ElementwiseOperation", this->acc_elem_op}, + {"B1ElementwiseOperation", this->b1_elem_op}, + {"CElementwiseOperation", this->c_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", 
std::to_string(this->tile_desc.block_size)}, + {"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)}, + {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, + {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, + {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, + {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"B1K1", std::to_string(this->tile_desc.b1k1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)}, + {"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)}, + {"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"B0BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b0_block_transfer.thread_cluster_length}, + {"B0BlockTransferThreadClusterArrangeOrder", + this->b0_block_transfer.thread_cluster_arrange_order}, + {"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order}, + {"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)}, + {"B0BlockTransferSrcScalarPerVector", + 
std::to_string(this->b0_block_transfer.src_scalar_per_vector)}, + {"B0BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)}, + {"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)}, + {"B1BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b1_block_transfer.thread_cluster_length}, + {"B1BlockTransferThreadClusterArrangeOrder", + this->b1_block_transfer.thread_cluster_arrange_order}, + {"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order}, + {"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)}, + {"B1BlockTransferSrcScalarPerVector", + std::to_string(this->b1_block_transfer.src_scalar_per_vector)}, + {"B1BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)}, + {"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CBlockTransferScalarPerVector_NWaveNPerXdl", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)}, + }; + + return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index fff75c1962..fe556615e0 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ 
b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) // accounts for all possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } +// clang-format off +// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, + +// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +// clang-format on + // Hard-code tuning parameters in modularized fashion, string them together into a vector of // instances std::vector Operation_Xdl_CShuffle::CreateOperations( @@ -83,6 +89,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, +// Irregular tile + { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1}, // clang-format on }; @@ -100,6 +108,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -109,15 +119,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| // 
Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, }; std::vector b_block_descriptions_rowmajor = { @@ -134,6 +146,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on }; @@ -151,6 +165,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -167,6 +183,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 1, 1}, { 1, 1}, { 1, 1}, + { 1, 1}, { 1, 1}, // clang-format on }; @@ -185,6 +202,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<1, 16, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, { S<1, 32, 
1, 8>, 8}, +// Irregular tile + { S<1, 16, 1, 4>, 1}, // clang-format on }; @@ -199,33 +218,44 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); - // Put all values together into a single operation > store into the result vector - for(std::size_t i = 0; i < tile_descriptions.size(); i++) + const std::vector> scheduler_pipeline_descriptions = + { + {LoopScheduler::Default, PipelineVersion::v1}, + {LoopScheduler::Interwave, PipelineVersion::v1}, + {LoopScheduler::Default, PipelineVersion::v2}, + }; + for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions) { - Operation_Xdl_CShuffle x; - x.tile_desc = tile_descriptions[i]; - x.a_block_transfer = a_block_descriptions[i]; - x.b_block_transfer = b_block_descriptions[i]; - x.cshuffle = cshuffle_descriptions[i]; - x.c_block_transfer = c_block_descriptions[i]; - x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; - x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; - x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; - x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { - return TensorDesc{dt, ToLayout(trans)}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; - x.gemm_specialization = GetGemmSpec(prob.M, - prob.N, - prob.K, - x.tile_desc.m_per_block, - x.tile_desc.n_per_block, - x.tile_desc.k_per_block); - x.update_prologue(prologue); - x.update_epilogue(epilogue); - result.push_back(x); + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = 
c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + x.loop_scheduler = loop_scheduler; + x.pipeline_version = pipeline_version; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } } return result; } @@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " - "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>"; // use hardcoded instances from vector of operations to substitute values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const @@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, {"CDEBlockTransferScalarPerVector_NPerBlock", std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"LoopScheduler", ToString(this->loop_scheduler)}, + {"PipelineVersion", ToString(this->pipeline_version)}, }; return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index 9aa5d39fae..a60e36ca4a 100644 --- a/codegen/src/types.cpp +++ 
b/codegen/src/types.cpp @@ -59,6 +59,26 @@ std::string ToString(GemmType gt) throw std::runtime_error("Incorrect gemm type"); } +std::string ToString(LoopScheduler ls) +{ + switch(ls) + { + case LoopScheduler::Default: return "ck::LoopScheduler::Default"; + case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave"; + } + throw std::runtime_error("Incorrect LoopScheduler type"); +} + +std::string ToString(PipelineVersion pv) +{ + switch(pv) + { + case PipelineVersion::v1: return "ck::PipelineVersion::v1"; + case PipelineVersion::v2: return "ck::PipelineVersion::v2"; + } + throw std::runtime_error("Incorrect PipelineVersion type"); +} + std::string SequenceStr(const std::vector& v) { return "ck::Sequence<" + diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp index af2f4a9122..3163bb08ed 100644 --- a/codegen/test/rtc/include/rtc/hip.hpp +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace rtc { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 83a1e82d6d..c05660c8ab 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= deterministic == "f" if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) @@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / 
GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 1c9d743f3d..ad8daba17e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != None: if not fnmatch.fnmatch(k.name, kernel_filter): continue - if receipt == 2: + if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: continue + if receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 5b1b6664cc..a0fb42aa11 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -103,7 +103,8 @@ if __name__ == "__main__": required=False, help="codegen receipt. 
0: generate only 8xhdim coverage\n" + \ " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration" + " 2: Only generate instance for Flash attention integration\n" + \ + " 4: Only generate instance for PyTorch integration" ) args = parser.parse_args() diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 2e04780eb0..5dc7b9cd0b 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index 5fa94f5f72..ed02f89fac 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 028f8a44c3..13a1c30e43 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + float ave_time = + gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " A_Layout =" << ALayout::name - << " B_Layout =" << BLayout::name - << " C_Layout =" << CLayout::name - << " A Type = " << DataTypeTraits::name - << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; + << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } @@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename 
GemmBasicTypeConfig::ADataType; - using BDataType = typename GemmBasicTypeConfig::BDataType; - using CDataType = typename GemmBasicTypeConfig::CDataType; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + invoke_gemm( + a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; @@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc, a_m_k, b_k_n, c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol - (K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), @@ -171,7 +173,7 @@ int run_gemm_example_with_layouts(int argc, std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { @@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol - (K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), @@ -229,7 +231,7 @@ int run_gemm_example_with_layouts(int argc, std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } return pass; diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index d2c4df1058..c4faa35e33 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[]) .insert("k", "4", "topk") .insert("unit", "32", "unit_size") .insert("moe_buf_size", "0", "moe_buf_size") + .insert("local_eid", + "-1", + "a list of experts enabled as local expert. e.g. 
\"0,1,4,5\"\n" + "please make sure eid is in ascending order!") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "when set to 1 it will print kernel name") .insert("warmup", "5", "number of iterations before benchmark the kernel") @@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); + int max_output_ids = ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); @@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) return false; } + bool local_expert_masking = args.get_str("local_eid") != "-1"; + auto local_expert_masking_host = [&]() { + if(local_expert_masking) + { + auto local_eid = args.get_int_vec("local_eid"); + // std::vector v_ {num_experts, 0}; + ck_tile::HostTensor v_{{num_experts}}; + v_.SetZero(); + for(auto eid : local_eid) + { + if(eid >= num_experts) + { + throw std::runtime_error( + "local_eid larger than number of expert, please check"); + } + v_.mData[eid] = 1; + } + return v_; + } + else + // return std::vector{}; + return ck_tile::HostTensor{{1}}; + }(); + // tokens already considered batch size ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); @@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) sorted_expert_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem local_expert_masking_dev( + local_expert_masking_host.get_element_space_size_in_bytes()); topk_ids_dev.ToDevice(topk_ids_host.data()); weights_dev.ToDevice(weights_host.data()); @@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args) { moe_buf_dev.ToDevice(moe_buf_host.data()); } + 
if(local_expert_masking) + local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); - moe_sorting_trait trait{index_prec, weight_prec}; + moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking}; moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() + : nullptr, sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(), @@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args) warmup, repeat}; auto ms = moe_sorting(trait, karg, sc); - printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ", index_prec.c_str(), weight_prec.c_str(), tokens, num_experts, - topk, - ms); + topk); + + if(local_expert_masking) + { + printf("local_eid:%s, ", args.get_str("local_eid").c_str()); + } + if(ms < 0) printf("not supported\n"); + else + printf("ms:%f, ", ms); fflush(stdout); if(ms < 0) { @@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args) int32_t ref_total_tokens_post_pad = 0; ck_tile::reference_moe_sorting(topk_ids_host, weights_host, + local_expert_masking_host, sorted_ids_ref, sorted_weights_ref, sorted_expert_ids_ref, ref_total_tokens_post_pad, num_experts, - unit_size); + unit_size, + local_expert_masking); rtn &= ck_tile::check_err( sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); rtn &= ck_tile::check_err(sorted_weights_host, @@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); } rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + printf("total_tokens_post_pad:%d(%d), ", + ref_total_tokens_post_pad, + sorted_id_cnt_host.mData[0]); } - printf("valid:%s\n", rtn ? "y" : "n"); + printf("valid:%s", rtn ? 
"y" : "n"); + fflush(stdout); + if(!rtn) + printf(", (%d)", seed); + printf("\n"); fflush(stdout); return rtn; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 723fb3f69f..abff24a669 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -3,6 +3,12 @@ #include "moe_sorting_api.hpp" +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,67 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else + +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + constexpr bool local_expert_masking = local_expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \ + if(row_ % 8 == 0) \ + { \ + MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 4 == 0) \ + { \ + MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \ + } \ + else if(row_ % 2 == 0) \ + { \ + MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_(1, 
sub_token_onshot_, local_expert_masking_); \ + } + +#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \ + if(is_sub_token_onshot) \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \ + } + +#define MOE_SORTING_DISPATCH_EMASK_(row_) \ + if(is_local_expert_masking) \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, true) \ + } \ + else \ + { \ + MOE_SORTING_DISPATCH_SUBTO_(row_, false) \ + } + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +105,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + bool is_local_expert_masking = t.local_expert_masking; + (void)c_; + + MOE_SORTING_DISPATCH_EMASK_(r_); + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index 0cb393f7de..5bda4d368a 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -10,7 +10,8 @@ struct moe_sorting_trait { std::string index_type; - std::string weight_type; // currently always float + std::string weight_type; // currently always float + bool 
local_expert_masking; // if mask experts as local expert }; struct moe_sorting_args : public ck_tile::MoeSortingHostArgs diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index 3ff8a7332d..cf2c2e164b 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11 $EXE -t=1 -e=1 -k=1 $EXE -t=99 -e=2 -k=1 $EXE -t=333 -e=99 -k=13 +$EXE -t=11 -e=256 -k=5 +$EXE -t=64 -e=455 -k=8 +$EXE -t=777 -e=802 -k=99 +$EXE -t=4097 -e=906 -k=51 $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 +$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 +$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 +$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 +$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 diff --git a/example/ck_tile/15_fused_moe/README.md b/example/ck_tile/15_fused_moe/README.md index b6ceabf351..089e1de78e 100644 --- a/example/ck_tile/15_fused_moe/README.md +++ b/example/ck_tile/15_fused_moe/README.md @@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator: // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7ca24c5c9a..805cd54878 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -3,6 +3,12 @@ #include "fused_moesorting.hpp" +#ifndef 
MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + +#if !MOE_SORTING_USE_EX_KERNEL + #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \ @@ -17,6 +23,24 @@ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ return ave_time; +#else +#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \ + constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \ + constexpr bool sub_token_onshot = sub_token_onshot_; \ + using ms_problem = \ + ck_tile::MoeSortingProblemEx; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +#endif + +#if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH(unroll_num_) \ if(a.num_experts <= 8) \ { \ @@ -38,11 +62,13 @@ { \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ } +#endif float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") { +#if !MOE_SORTING_USE_EX_KERNEL if(a.num_experts > 127) { printf("lds size exceed, only support experts <127 \n"); @@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til MOE_SORTING_DISPATCH(4); } } +#else + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); + auto sub_token_ = r_ - 2; + r_ = (r_ - 2) / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; + (void)c_; + if(is_sub_token_onshot) + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, true); + } + else if(r_ % 4 == 0) + { + 
MOE_SORTING_DISPATCH_(4, true); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, true); + } + else + { + MOE_SORTING_DISPATCH_(1, true); + } + } + else + { + if(r_ % 8 == 0) + { + MOE_SORTING_DISPATCH_(8, false); + } + else if(r_ % 4 == 0) + { + MOE_SORTING_DISPATCH_(4, false); + } + else if(r_ % 2 == 0) + { + MOE_SORTING_DISPATCH_(2, false); + } + else + { + MOE_SORTING_DISPATCH_(1, false); + } + } + // MOE_SORTING_DISPATCH_ETILE(0, 0); +#endif } return -1; } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 949621e116..286fe4201d 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index d0df8845cc..1105304e3e 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc, << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The CPU veification result is:" << (pass ? 
"correct" : "fail") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index c32fac6c0d..03d5818179 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -118,7 +118,7 @@ float grouped_gemm(const std::vector& gemm_descs, if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" + std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index b0a3e9973c..080ea818c9 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc, << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; } - std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; } return pass; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index bfbcebd7c8..ea5a5d0e16 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return true; } + static constexpr bool + IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) + { + // check vector load/store + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B + if constexpr(is_same_v) + { + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B1 + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of C + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + + 
return true; + } + static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) @@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; - // Check scalar per vector requirement - const auto a_extent_lowest = - is_same_v ? KRaw : MRaw; - const auto b_extent_lowest = - is_same_v ? NRaw : KRaw; - const auto b1_extent_lowest = - is_same_v ? Gemm1NRaw : NRaw; - const auto c_extent_lowest = - is_same_v ? Gemm1NRaw : MRaw; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); } // polymorphic @@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle return str.str(); } + + template + struct Descriptor + { + template + static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc) + { + const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc) + { + const auto b_grid_desc_n_k = 
DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc) + { + const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc) + { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, 
+ Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN, + MaskOutUpperTriangle>; + + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; + CGridDesc_M_N c_grid_desc_m_n; + C0MatrixMask c0_matrix_mask; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_descriptor_mblock_mperblock_nblock_nperblock; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + B1ElementwiseOperation b1_element_op; + CElementwiseOperation c_element_op; + + bool has_main_k_block_loop = true; + bool is_valid = false; + + constexpr Descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + B1ElementwiseOperation b1_element_op_, + CElementwiseOperation c_element_op_) + : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)}, + 
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, + b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, + c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, + block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, + c_grid_descriptor_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + c0_matrix_mask{c.GetLength(I1)}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + b1_element_op{b1_element_op_}, + c_element_op{c_element_op_}, + is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1))} + { + } + + constexpr bool IsValid() const { return is_valid; } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{}, + CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + return Descriptor( + a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const float scale, + const ADataType* __restrict__ p_a_grid, + const ADataType* __restrict__ p_b_grid, + const ADataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid) + { +#ifndef __HIPCC_RTC__ + assert(desc.is_valid); +#endif + __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; + 
AccElementwiseOperation acc_element_op{scale}; + + if(desc.has_main_k_block_loop) + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + else + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index b4cf996a48..795995d9a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle // if workspace is not allocated if(!arg.p_workspace_) { - std::cerr << "Warning: Workspace for " - "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " - "allocated, use SetWorkSpacePointer." - << std::endl; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." 
+ << std::endl; + } return false; } if(!ck::is_xdl_supported()) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index ba4f4b6e7d..a8c95b9c38 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,12 +27,12 @@ #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/int8.hpp" -#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/tensor/buffer_view.hpp" diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 39a904717c..5a5e01460f 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -5,6 +5,7 @@ #include "ck_tile/host/arg_parser.hpp" #include "ck_tile/host/check_err.hpp" +#include "ck_tile/host/concat.hpp" #include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp" #include "ck_tile/host/convolution_parameter.hpp" #include "ck_tile/host/device_memory.hpp" diff --git a/include/ck_tile/host/concat.hpp b/include/ck_tile/host/concat.hpp new file mode 100644 index 0000000000..c68b908149 --- /dev/null +++ b/include/ck_tile/host/concat.hpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +template +struct IsCharArray : std::false_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +struct IsCharArray : std::true_type +{ +}; + +template +inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v || + IsCharArray::value || + std::is_same_v)&&...); + +template +[[nodiscard]] auto concat(const Ts&... xs) + -> std::enable_if_t, std::string> +{ + using ::operator<<; + thread_local std::ostringstream oss; + oss.str(""); + + (oss << ... << xs); + return oss.str(); +} + +template +[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept +{ + return N; +} + +template +[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept +{ + return N; +} + +[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept +{ + const char* end = s; + while(*end++ != 0) {} + return end - s - 1; +} + +[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; } + +[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); } + +[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept +{ + return s.size(); +} + +template +auto concatInto(std::string& result, const Ts&... xs) + -> std::enable_if_t, void> +{ + const std::size_t space = (1 + ... + getSize(xs)); + result.reserve(result.size() + space); + ((result += xs), ...); +} + +template +[[nodiscard]] auto concat(const Ts&... xs) + -> std::enable_if_t, std::string> +{ + std::string result; + concatInto(result, xs...); + return result; +} + +// Function for types convertible to std::string_view +template +[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... 
rest) + -> std::enable_if_t, std::string> +{ + std::string result; + result += first; + ((result += sep, result += rest), ...); + return result; +} + +// Function for other types +template +[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest) + -> std::enable_if_t, std::string> +{ + using ::operator<<; + thread_local std::ostringstream oss; + oss.str(""); + oss << first; + ((oss << sep << rest), ...); + return oss.str(); +} + +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 3851629cc2..47f0ba576b 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -14,12 +14,15 @@ namespace ck_tile { template CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, const HostTensor& weights, + const HostTensor& local_expert_mask, HostTensor& p_sorted_token_ids, HostTensor& sorted_weight, HostTensor& sorted_expert_ids, index_t& unit_cnt, const index_t experts, - const index_t unit_size) + const index_t unit_size, + bool local_expert_masking, + bool skip_experts_with_zero_token = true) { const index_t num_token = topk_ids.mDesc.get_lengths()[0]; const index_t topk = topk_ids.mDesc.get_lengths()[1]; @@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, #endif std::vector> expert_token_weights( experts, std::vector(unit_size, 0)); + // count number of unit-size slices in this expert std::vector expert_slices(experts, 1); + // count the tokens used in this expert std::vector expert_slice_idxs(experts, 0); + // TODO: above 2 buffer seems duplicated for(index_t t = 0; t < num_token; t++) { @@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, IndexType* out_tokens = p_sorted_token_ids.data(); WeightType* out_weights = sorted_weight.data(); IndexType* out_expert_id = sorted_expert_ids.data(); + int 
curr_expert_id = 0; for(index_t e = 0; e < experts; e++) { + if(local_expert_masking) + { + if(local_expert_mask(e) == 0) + continue; + } + if(skip_experts_with_zero_token) + { + if(expert_slice_idxs[e] == 0) + { + curr_expert_id++; + continue; + } + } + memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); out_tokens += expert_slices[e] * unit_size; memcpy(out_weights, @@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, for(index_t s = 0; s < expert_slices[e]; s++) { - out_expert_id[s] = e; + out_expert_id[s] = curr_expert_id; unit_cnt++; } out_expert_id += expert_slices[e]; + curr_expert_id++; } unit_cnt *= unit_size; return; diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 8b5302257c..1768c802d5 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -10,3 +10,4 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index ade2f18041..200e2a618c 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 9b9bf30ad3..027e2fdd94 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -5,3 +5,4 @@ #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include 
"ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp new file mode 100644 index 0000000000..8592f93e0f --- /dev/null +++ b/include/ck_tile/ops/common/utils.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// clang-format off +template struct typeToStr; +template <> struct typeToStr { static constexpr const char * name = "fp32"; }; +template <> struct typeToStr { static constexpr const char * name = "fp16"; }; +template <> struct typeToStr { static constexpr const char * name = "bf16"; }; +template <> struct typeToStr { static constexpr const char * name = "fp8"; }; +template <> struct typeToStr { static constexpr const char * name = "bf8"; }; +template <> struct typeToStr { static constexpr const char * name = "int8"; }; +// clang-format on + +template +std::string gemm_prec_str() +{ + std::string base_str = std::string(typeToStr::name); + if(!std::is_same_v) + { + base_str += "_" + std::string(typeToStr::name); + } + return base_str; +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 15fa269740..53187771b9 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -6,3 +6,4 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 95ead2645e..9d2ed407c9 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -8,3 +8,4 @@ #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include 
"ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 616db2fa5b..82f6d48eda 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 4cbb59e95b..c896534e03 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -44,3 +44,4 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index d2d328fc46..ddb64a2189 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" #include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" #include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" +#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp" #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp" @@ -14,6 +15,6 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" -#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" #include 
"ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index a7eeb3c0e3..efa1ccb311 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -22,7 +22,7 @@ // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 30e68996b6..340f6cb9e5 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -15,6 +15,10 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) +#ifndef MOE_SORTING_USE_EX_KERNEL +#define MOE_SORTING_USE_EX_KERNEL 1 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -28,7 +32,7 @@ namespace ck_tile { // (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] // -// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated) // * this could 
be larger than actual, since actual tokens are on GPU // // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] @@ -55,6 +59,34 @@ namespace ck_tile { // num_tokens_post_padded_ptr : [28] // num_sorted_tiles_ptr : [7] // +// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens) +// if enabled, the expert with no tokens will be skipped, instead of padding to at least 1 unit_size(M_a) +// +// (pack below tensor, skip element marked with `-`) +// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// +// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5] +// num_tokens_post_padded_ptr : [24] +// +// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case) +// and modify the output expert-ID, because we will only have enabled experts on a specific GPU. 
+// we call expert input to this kernel as "global expert id", output as "local expert id" +// +// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4) +// +// (pack below tensor, skip element marked with `-`) +// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note original it was exper-id= 0, 2, 3, 5, but we produce "local expert id") +// num_tokens_post_padded_ptr : [20] +// // * different from vLLM // 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id // 2)need sorted_weight_ptr @@ -67,10 +99,80 @@ namespace ck_tile { // 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) // // max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) + + +CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_) +{ + /* num_experts + 1 + * +--------------------------------------+ + * | | + * | | + * | | * -> sub-tokens + * | | + * | | + * +--------------------------------------+ + * | | 2 -> cumsum buffer + * +--------------------------------------+ + * + */ + int smem_cols = num_experts_ + 1; // usually experts is power of 2. 
padding here + int smem_rows = [&](){ + index_t target_occupancy_ = 2; + constexpr index_t total_ = 65536 / sizeof(int); + constexpr index_t sub_unroll = 8; + constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt + // at lease 2 lines, one for sub_token unroll, one for cumsum + // should be enough + if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) { + if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols)) + throw std::runtime_error("too many num_experts, can't allocate smem"); + target_occupancy_ = 1; + } + int r = total_ / target_occupancy_ / smem_cols; + + // round to sub_unroll multipl + int r_for_sub_token = r - cumsum_bufs; + r_for_sub_token = min(r_for_sub_token, num_tokens_); + r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll; + r_for_sub_token = max(r_for_sub_token, 1); + + if(r_for_sub_token > 1) + { + int r_unroll_ = r_for_sub_token / sub_unroll; + + + // round to 1x/2x/4x/8x number of sub_unroll + int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29 + int mask_ = (1 << (31 - clz_)) - 1; + + + mask_ = mask_ > 0b111 ? 
0b111 : mask_; //clamp to 8x at most + mask_ = ~mask_; + //printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout); + + r_for_sub_token = (r_unroll_ & mask_) * sub_unroll; + } + + // final check + if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) { + throw std::runtime_error("can't run this kernel, request LDS over size"); + } + + return r_for_sub_token + cumsum_bufs; + }(); + + // printf("r:%d, c:%d\n", smem_rows, smem_cols); + + return ck_tile::make_tuple(smem_rows, smem_cols); +} + struct MoeSortingHostArgs { const void* p_topk_ids; // [token, topk] const void* p_weights; // [token, topk] + + const void* p_local_expert_mask; + void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; @@ -101,6 +203,7 @@ struct MoeSortingKernel { const void* p_topk_ids; const void* p_weights; + const void* p_local_expert_mask; void* p_sorted_token_ids; void* p_sorted_weights; void* p_sorted_expert_ids; @@ -111,8 +214,11 @@ struct MoeSortingKernel index_t moe_buf_bytes; index_t tokens_per_thread; + index_t smem_rows; mdiv unit_size_mdiv; mdiv topk_mdiv; + mdiv expert_mdiv; + // mdiv sub_tokens_mdiv; }; CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) @@ -123,15 +229,25 @@ struct MoeSortingKernel CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) { +#if MOE_SORTING_USE_EX_KERNEL + (void)h; + return dim3(256); +#else return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size())); +#endif } // in byte CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) { +#if MOE_SORTING_USE_EX_KERNEL + auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); + return smem_rows * smem_cols * sizeof(int); +#else const auto blocks = BlockSize(h); // usually num_experts is power of 2, we pad 1 dword here for the row-size return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t); +#endif } CK_TILE_HOST static constexpr 
auto MakeKargs(const Hargs& h) @@ -139,6 +255,7 @@ struct MoeSortingKernel Kargs k; k.p_topk_ids = h.p_topk_ids; k.p_weights = h.p_weights; + k.p_local_expert_mask = h.p_local_expert_mask; k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_sorted_expert_ids = h.p_sorted_expert_ids; @@ -152,10 +269,18 @@ struct MoeSortingKernel k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; k.topk_mdiv = mdiv{static_cast(h.topk)}; + k.smem_rows = [&](){ + auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); + (void) c_; + return r_; + }(); + k.expert_mdiv = mdiv{static_cast(h.num_experts)}; + // k.sub_tokens_mdiv = mdiv{static_cast(k.smem_rows - 1)}; return k; } - // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] + // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....] + // NOTE: wave_size need at least be 16!! dpp 16 is one row template __device__ inline void wave_cumsum(data_t& thread_data) const { @@ -196,6 +321,40 @@ struct MoeSortingKernel bank_mask, bound_ctrl))); // row_shr:4 } + if constexpr(wave_size == 8) { + + // wave-size=8 need one extra shift + thread_data = + reduce_op(thread_data, + __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x118, + row_mask, + bank_mask, + bound_ctrl))); // row_shr:8 +#if 0 + constexpr int bank_mask_0_7 = 0b1100; + auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; + thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, + __builtin_amdgcn_update_dpp(0, /* old value */ + __builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask_0_7, + bound_ctrl))// row_newbcast:7 + ); +#else + data_t xxx =__builtin_bit_cast(data_t, + __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), + 0x157, + row_mask, + bank_mask, + bound_ctrl)); // row_newbcast:7 + + data_t yyy = (__lane_id() / 8) % 2 == 0 ? 
0 : xxx; + thread_data = thread_data - yyy; +#endif + + } if constexpr(wave_size > 8) { thread_data = @@ -224,6 +383,36 @@ struct MoeSortingKernel } } + // reduce single pixel within a wave + template + __device__ static constexpr T wave_reduce(T local, F reduce_f, number = {}) + { + // constexpr int wave_size = 64; + // constexpr int reduce_stage = 6; // 1<<6=64 + // clang-format off + constexpr int reduce_stage = [](){ + if constexpr(wave_size_ == 2) return 1; + else if constexpr(wave_size_ == 4) return 2; + else if constexpr(wave_size_ == 8) return 3; + else if constexpr(wave_size_ == 16) return 4; + else if constexpr(wave_size_ == 32) return 5; + else if constexpr(wave_size_ == 64) return 6; + else return 0; + }(); + // clang-format on + T v_local = local; +#pragma unroll reduce_stage + for(int i_stage = 0; i_stage < reduce_stage; i_stage++) + { + int src_lane = __lane_id() ^ (1 << i_stage); + int32_t v_remote_tmp = + __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast(v_local)); + T v_remote = bit_cast(v_remote_tmp); + v_local = reduce_f(v_local, v_remote); + } + return v_local; + } + CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const { return row * total_col + col; @@ -257,37 +446,37 @@ struct MoeSortingKernel index_t* shared_mem = reinterpret_cast(smem); index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) - index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1) + index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1) for(int i = 0; i < num_experts; ++i) { - tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0; + tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0; } #pragma unroll Problem_::InternalLoadUnroll for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])]; + ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])]; } 
__syncthreads(); #if 1 if(tid < num_experts) { - tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; + tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0; index_t local_c[8]; index_t prev_c = 0; // TODO: manually unroll. pragma unroll does not work well when we have dependency - for(int i = 1; i <= static_cast(blockDim.x); i+= 8) + for(int i = 1; i <= static_cast(blockDim.x); i += 8) { - local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)]; - local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)]; - local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)]; - local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)]; - local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)]; - local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)]; - local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)]; - local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)]; + local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)]; + local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)]; + local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)]; + local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)]; + local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)]; + local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)]; + local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)]; + local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)]; local_c[0] += prev_c; local_c[1] += local_c[0]; @@ -299,51 +488,57 @@ struct MoeSortingKernel local_c[7] += local_c[6]; prev_c = local_c[7]; - tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0]; - tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1]; - tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2]; - tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3]; - tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4]; - tokens_cnts[calc_index(num_experts+1, i + 
5, tid)] = local_c[5]; - tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6]; - tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7]; + tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0]; + tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1]; + tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2]; + tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3]; + tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4]; + tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5]; + tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6]; + tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7]; } } #else - // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic + // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future + // heuristic { if(tid < num_experts) - tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0; - for(int i = 0; i < num_experts; i+=8) { + tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0; + for(int i = 0; i < num_experts; i += 8) + { index_t local_c[8]; - #pragma unroll - for(int j = 0; j < 8; j++) { - local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)]; +#pragma unroll + for(int j = 0; j < 8; j++) + { + local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)]; } - #pragma unroll - for(int j = 0; j < 8; j++) { +#pragma unroll + for(int j = 0; j < 8; j++) + { wave_cumsum(local_c[j]); } - #pragma unroll - for(int j = 0; j < 8; j++) { - tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j]; +#pragma unroll + for(int j = 0; j < 8; j++) + { + tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j]; } } } #endif __syncthreads(); - if constexpr (Problem::ExpertTile == 0) { + if constexpr(Problem::ExpertTile == 0) + { if(tid == 0) { cumsum[0] = 0; for(int i = 1; i <= num_experts; ++i) { auto current_units = [&]() { - index_t 
x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] + - unit_size_mdiv.divisor - 1; + index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] + + unit_size_mdiv.divisor - 1; index_t y_ = unit_size_mdiv.div(x_); return max(y_, 1) * unit_size_mdiv.divisor; }(); @@ -351,20 +546,24 @@ struct MoeSortingKernel } *p_total_tokens_post_pad = cumsum[num_experts]; } - } else { - // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert) - // for simplicity, not check experts here. - int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + } + else + { + // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= + // expert) for simplicity, not check experts here. + int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)]; int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; - int local_cumsum = padded_tokens_per_expert; + int local_cumsum = padded_tokens_per_expert; wave_cumsum(local_cumsum); - if(tid == (num_experts - 1)) { - cumsum[0] = 0; + if(tid == (num_experts - 1)) + { + cumsum[0] = 0; *p_total_tokens_post_pad = local_cumsum; } - if(tid < num_experts) { + if(tid < num_experts) + { cumsum[tid + 1] = local_cumsum; } } @@ -373,7 +572,7 @@ struct MoeSortingKernel if(tid < num_experts) { int e_start = cumsum[tid]; - int e_end = cumsum[tid + 1]; + int e_end = cumsum[tid + 1]; for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) { p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; @@ -383,8 +582,8 @@ struct MoeSortingKernel #pragma unroll Problem_::InternalLoadUnroll for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - index_t expert_id = topk_id[i]; - index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)]; + index_t expert_id = topk_id[i]; + index_t local_cnt = 
tokens_cnts[calc_index(num_experts + 1, tid, expert_id)]; index_t rank_post_pad = local_cnt + cumsum[expert_id]; #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID uint32_t curr_token_id, curr_topk_id; @@ -393,16 +592,17 @@ struct MoeSortingKernel #else p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i); #endif - p_sorted_weights[rank_post_pad] = weights[i]; - tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1; + p_sorted_weights[rank_post_pad] = weights[i]; + tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1; } - if constexpr (Problem::ExpertTile == 0) { + if constexpr(Problem::ExpertTile == 0) + { const index_t prefill_token = topk_mdiv.div(numel); if(tid < num_experts) { index_t expert_offset = - cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)]; + cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)]; index_t expert_end = cumsum[tid + 1]; while(expert_offset < expert_end) { @@ -417,16 +617,19 @@ struct MoeSortingKernel } } } - else { + else + { const index_t prefill_token = topk_mdiv.div(numel); // TODO: only support expert-tile like 8, 16, 32 static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile; { - index_t eid = tid / experts_per_wave; - index_t expert_offset = - cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave; + index_t eid = tid / experts_per_wave; + index_t expert_offset = cumsum[eid] + + tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] + + tid % experts_per_wave; index_t expert_end = cumsum[eid + 1]; - if(eid < num_experts) { + if(eid < num_experts) + { while(expert_offset < expert_end) { #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID @@ -436,10 +639,363 @@ struct MoeSortingKernel p_sorted_token_ids[expert_offset] = prefill_token; #endif p_sorted_weights[expert_offset] = static_cast(0.0); - expert_offset+=experts_per_wave; + expert_offset += experts_per_wave; } } - } + } + } + } + + // only support index_t, and 
single pixel access + struct simple_smem_indexer + { + index_t* smem; + index_t row_stride; + + // this is 2D + CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_) + : smem(smem_), row_stride(row_stride_) + { + } + CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const + { + return smem[i_row * row_stride + i_col]; + } + CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col) + { + return smem[i_row * row_stride + i_col]; + } + + // this is 1D or linear + CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {} + CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; } + CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; } + }; + + CK_TILE_DEVICE void + moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id, + const WeightType* __restrict__ weights, + const IndexType* __restrict__ local_expert_mask, + index_t* p_sorted_token_ids, + WeightType* p_sorted_weights, + index_t* p_sorted_expert_ids, + index_t* p_total_tokens_post_pad, + const index_t num_experts, + const index_t tokens, + const mdiv unit_size_mdiv, + const mdiv topk_mdiv, + const mdiv expert_mdiv, + const index_t smem_rows, + void* smem) const + { + const index_t tid = static_cast(threadIdx.x); + const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize); + const index_t lid = __lane_id(); + constexpr index_t block_size = 256; // blockDim.x; + const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; + const index_t topk = topk_mdiv.divisor; + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + + const index_t smem_cols = num_experts + 1; + + simple_smem_indexer smem_cumsum{reinterpret_cast(smem) + 0}; + simple_smem_indexer smem_cumdup{reinterpret_cast(smem) + smem_cols}; + simple_smem_indexer smem_tokens{reinterpret_cast(smem) + 2 * smem_cols, + smem_cols}; + + // #pragma unroll 8 + for(int i = tid; i < (sub_tokens * num_experts); i += block_size) + 
{ + uint32_t curr_token_id, curr_expert_id; + expert_mdiv.divmod(i, curr_token_id, curr_expert_id); + smem_tokens(curr_token_id, curr_expert_id) = 0; + } + __syncthreads(); + + for(int i_token = 0; i_token < tokens; i_token += sub_tokens) + { + // NOTE: below for loop can't have barrier inside!! + for(int i = tid; i < (sub_tokens * topk); i += block_size) + { + uint32_t curr_token_id, curr_topk_id; + topk_mdiv.divmod(i, curr_token_id, curr_topk_id); + int i_t = i_token + curr_token_id; + + if(i_t < tokens) + { + int eid = topk_id[i_t * topk + curr_topk_id]; + + if constexpr(Problem::SubTokenOneShot) + smem_tokens(curr_token_id, eid) = curr_topk_id + 1; + else + smem_tokens(curr_token_id, eid)++; + } + __builtin_amdgcn_s_waitcnt(0xc07f); + } + __syncthreads(); // make sure different i_token iteration not overlap by different wave + } + + // counting + if(tid == 0) + { + smem_cumsum(0) = 0; + // smem_cumdup(0) = 0; + } + + { + constexpr int lane_group_sz = 8; + int lane_group_id = tid / lane_group_sz; + int lane_group_os = tid % lane_group_sz; + constexpr int lane_group_nm = block_size / lane_group_sz; + + for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm) + { + index_t local_c[Problem::SubTokenTile]; + index_t cnt = 0; + + for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile) + { +#pragma unroll Problem::SubTokenTile + for(int j = 0; j < Problem::SubTokenTile; j++) + { + local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e); + if constexpr(Problem::SubTokenOneShot) + { + local_c[j] = local_c[j] != 0 ? 
1 : 0; + } + } + +#pragma unroll Problem::SubTokenTile + for(int j = 0; j < Problem::SubTokenTile; j++) + { + cnt += wave_reduce(local_c[j], f_sum, number<8>{}); + } + } + if(lane_group_os == 0) + smem_cumsum(i_e + 1) = cnt; + } + } + + if constexpr(Problem::LocalExpertMasking) + { + smem_cumdup(0) = 0; + for(int i_e = tid; i_e < num_experts; i_e += block_size) + { + // reuse this buffer + smem_cumdup(i_e + 1) = local_expert_mask[i_e]; + } + } + + __syncthreads(); + + { + if(wid == 0) + { + // NOTE: under this block can never use __syncthreads! + int i_e_ = 0; + int local_cumsum_ = 0; + for(; i_e_ < num_experts; i_e_ += warpSize) + { + int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0); + int local_cnt = smem_cumsum(i_e_ + lid + 1); + int blocks_pers_expert = + unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1); + + int pre_cumsum_masking = [&]() { + if constexpr(Problem::LocalExpertMasking) + return smem_cumdup(lid == 0 ? i_e_ : 0); + else + return 0; // not used + }(); + int local_masking = [&]() { + if constexpr(Problem::LocalExpertMasking) + return smem_cumdup(i_e_ + lid + 1); + else + return 0; // not used + }(); + int padded_tokens_per_expert = [&]() { + int x_ = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + // this is what we want to achieve + return blocks_pers_expert * unit_size_mdiv.divisor; + } + else + { + return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor; + } + }(); + if constexpr(Problem::LocalExpertMasking) + { + return local_masking ? 
x_ : 0; + } + else + return x_; + }(); + + local_cumsum_ = padded_tokens_per_expert; + local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local + // cumsum padded in case local cumsum is zero, but + // pre_cumsum has value, which will result in + // zero local cumsum(but we want at least padded) + wave_cumsum(local_cumsum_); + + if((i_e_ + lid) < num_experts) + smem_cumsum(i_e_ + lid + 1) = local_cumsum_; + + if constexpr(Problem::LocalExpertMasking) + { + local_masking += pre_cumsum_masking; + wave_cumsum(local_masking); + if((i_e_ + lid) < num_experts) + smem_cumdup(i_e_ + lid + 1) = local_masking; + } + + // NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt() + // for above write however __syncthreads will cause barrier with waves other + // than 0(which is not what we want) + __builtin_amdgcn_s_waitcnt(0xc07f); + } + if((lid + i_e_ - warpSize) == (num_experts - 1)) + { + *p_total_tokens_post_pad = local_cumsum_; + } + } + __syncthreads(); + } + + for(int i_e = tid; i_e < num_experts; i_e += block_size) + { + int e_start = smem_cumsum(i_e); + int e_end = smem_cumsum(i_e + 1); + + int expert_id = [&]() { + if constexpr(Problem::LocalExpertMasking) + { + // local expert id from cumsum + return smem_cumdup(i_e); + } + else + return i_e; + }(); + + smem_cumdup(i_e) = e_start; // duplicate cumsum for later use + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } + + if constexpr(Problem::LocalExpertMasking) + { + if(local_expert_mask[i_e] == 0) + continue; + } + + for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor) + { + p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id; + } + } + smem_cumdup(num_experts) = smem_cumsum(num_experts); + + // fill the p_sorted_token_ids/p_sorted_weights + for(int i_token = 0; i_token < tokens; i_token += sub_tokens) + { + if constexpr(!Problem::SubTokenOneShot) + { + // clear every time + for(int i = tid; i < (sub_tokens
* num_experts); i += block_size) + { + uint32_t curr_token_id, curr_expert_id; + expert_mdiv.divmod(i, curr_token_id, curr_expert_id); + smem_tokens(curr_token_id, curr_expert_id) = 0; + } + __syncthreads(); + + // load again + for(int i = tid; i < (sub_tokens * topk); i += block_size) + { + uint32_t curr_token_id_, curr_topk_id_; + topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_); + int curr_token_id = static_cast(curr_token_id_); + int curr_topk_id = static_cast(curr_topk_id_); + int i_t = i_token + curr_token_id; + if(i_t < tokens) + { + int eid = topk_id[i_t * topk + curr_topk_id]; + smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1 + } + } + __syncthreads(); + } + + { + constexpr int lane_group_sz = 8; + int lane_group_id = tid / lane_group_sz; + int lane_group_os = tid % lane_group_sz; + constexpr int lane_group_nm = block_size / lane_group_sz; + for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm) + { + if constexpr(Problem::LocalExpertMasking) + { + if(local_expert_mask[eid] == 0) + continue; + } + int position = smem_cumsum(eid); + for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens; + i_sub_token += lane_group_sz) + { + auto x = smem_tokens(i_sub_token, eid); + + int local_cnt_cache = x != 0 ? 
1 : 0; + int local_cnt = local_cnt_cache; + wave_cumsum(local_cnt); + if(x != 0) + { + // now x is topk value +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[position + local_cnt - 1] = + MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1); +#else + p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token; +#endif + p_sorted_weights[position + local_cnt - 1] = + weights[(i_token + i_sub_token) * topk + x - 1]; + } + + int remote_cnt = __builtin_amdgcn_ds_bpermute( + (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt); + + position += remote_cnt; + } + smem_cumsum(eid) = position; + } + } + __syncthreads(); + } + + // add the skip number + for(int eid = tid; eid < num_experts; eid += block_size) + { + int e_start = smem_cumsum(eid); + int e_end = smem_cumdup(eid + 1); + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) // skip zero token expert + continue; + } + while(e_start < e_end) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk); +#else + p_sorted_token_ids[e_start] = tokens; +#endif + p_sorted_weights[e_start] = static_cast(0.0); + e_start++; + } } } @@ -456,6 +1012,24 @@ struct MoeSortingKernel } const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; extern __shared__ char smem[]; +#if MOE_SORTING_USE_EX_KERNEL + (void)numel; + return moe_align_block_size_kernel_ex( + static_cast(kargs.p_topk_ids), + static_cast(kargs.p_weights), + static_cast(kargs.p_local_expert_mask), + static_cast(kargs.p_sorted_token_ids), + static_cast(kargs.p_sorted_weights), + static_cast(kargs.p_sorted_expert_ids), + static_cast(kargs.p_total_tokens_post_pad), + kargs.num_experts, + kargs.tokens, + kargs.unit_size_mdiv, + kargs.topk_mdiv, + kargs.expert_mdiv, + kargs.smem_rows, + smem); +#else return moe_align_block_size_kernel(static_cast(kargs.p_topk_ids), static_cast(kargs.p_weights), static_cast(kargs.p_sorted_token_ids), @@ -468,6 +1042,7 @@ struct 
MoeSortingKernel kargs.unit_size_mdiv, kargs.topk_mdiv, smem); +#endif } }; diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp new file mode 100644 index 0000000000..15effe7118 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include +#include + +namespace ck_tile { + +template +struct MoeSortingProblem +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t InternalLoadUnroll = + InternalLoadUnroll_; // TODO: need better design(like tile size) + static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out +}; + +template +struct MoeSortingProblemEx +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t SubTokenTile = SubTokenTile_; + static constexpr bool SubTokenOneShot = SubTokenOneShot_; + static constexpr bool LocalExpertMasking = LocalExpertMasking_; + static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; + static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8); + static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp deleted file mode 100644 index 50005c4402..0000000000 --- 
a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp +++ /dev/null @@ -1,28 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include -#include - -namespace ck_tile { - -template -struct MoeSortingProblem -{ - // TODO: this kernel only support warp per row - using WeightType = remove_cvref_t; - using IndexType = remove_cvref_t; - - static constexpr index_t WarpSize = get_warp_size(); - static constexpr index_t WarpsPerBlock = 1; - static constexpr index_t InternalLoadUnroll = - InternalLoadUnroll_; // TODO: need better design(like tile size) - static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out -}; -} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 5bbe0601b7..a94628a59a 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -46,3 +46,4 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 0f8bec3cf4..323c682f2c 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -1,9 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel, + concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + struct BatchedGemmKernelArgs : GemmKernelArgs { index_t batch_stride_A; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index aa31d1fccf..4ed3006c89 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -75,6 +76,13 @@ struct GemmKernel static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) { return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 13d3df02f9..751e7c0e1a 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel, + concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + __host__ static auto GetWorkSpaceSize(const std::vector& gemm_descs) -> 
std::size_t { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 0a40ca359e..eec3886e2f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -81,6 +82,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using Base::PrefetchStages; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrCompV3", BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Policy::template GetSmemSize(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index e23f0cda7d..f8dd2348cb 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -128,6 +129,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; + [[nodiscard]] CK_TILE_HOST 
static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrMem", + concat('x', MPerBlock, NPerBlock, KPerBlock), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + using Base::PrefetchStages; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 6f51e6b8a9..b18bf603a9 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include +#include #include "ck_tile/core.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index d9f04a87c3..a2a14d1017 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -39,6 +40,18 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; + static constexpr index_t kLdsAlignmentInBytes = 16; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV1", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), +
concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -75,8 +88,9 @@ struct GemmPipelineAGmemBGmemCRegV1 auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); constexpr index_t a_lds_block_space_size_aligned = - integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * - 16; + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), + kLdsAlignmentInBytes) * + kLdsAlignmentInBytes; // B tile in LDS BDataType* p_b_lds = static_cast( diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index 0417035fb6..ce2dc9fb96 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -25,6 +26,13 @@ struct GemmPipelineAGmemBGmemCRegV2 static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV2", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize)); + // clang-format on + } CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index a69f72626c..dd631876b4 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -35,9 +36,19 @@ struct GemmPipelineProblemBase static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadK = Traits::kPadK; - static constexpr auto Scheduler = GemmPipelineScheduler::Default; - + static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr index_t VectorLoadSize = Traits::_VectorSize; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler); + // clang-format on + } + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() { if constexpr(std::is_same_v) diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 2522abe5ed..24a399f18d 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -19,6 +20,16 @@ struct TileGemmShape static constexpr index_t kM = BlockTile::at(number<0>{}); static constexpr index_t kN = BlockTile::at(number<1>{}); static constexpr index_t kK = BlockTile::at(number<2>{}); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "tile_gemm_shape", + concat('x', kM, kN, kK, NumWarps), + concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})), + concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{}))); + // clang-format on + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index d54b7f60d6..93664ea138 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -8,3 +8,4 @@ #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 47d986e1c2..afbb817db1 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -11,3 +11,4 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 9392f8b439..7dc3e8b7e7 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -8,3 +8,4 @@ #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include 
"ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index f3abe84e46..1cc3d9cbc3 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -7,3 +7,4 @@ #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index b817d09c72..80ead84e85 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 73fd6bfb0e..3eec2a1ab6 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -11,3 +11,4 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 3fe1b5b213..dc164dc1a0 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -11,3 +11,4 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 391609622a..b23e869d81 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -7,3 +7,4 @@ 
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 40b9edd72f..1dc563f757 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -7,3 +7,4 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index efc1d17637..d0a810de4f 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -9,3 +9,4 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp"