mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into develop
This commit is contained in:
@@ -92,6 +92,7 @@ endif()
|
||||
add_compile_options(-Wno-bit-int-extension)
|
||||
add_compile_options(-Wno-pass-failed)
|
||||
add_compile_options(-Wno-switch-default)
|
||||
add_compile_options(-Wno-unique-object-duplication)
|
||||
|
||||
if(DL_KERNELS)
|
||||
add_definitions(-DDL_KERNELS)
|
||||
|
||||
6
Jenkinsfile
vendored
6
Jenkinsfile
vendored
@@ -117,7 +117,7 @@ def getDockerImage(Map conf=[:]){
|
||||
{
|
||||
echo "Pulling down image: ${image}"
|
||||
retimage = docker.image("${image}")
|
||||
withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
|
||||
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
|
||||
retimage.pull()
|
||||
}
|
||||
}
|
||||
@@ -148,7 +148,7 @@ def buildDocker(install_prefix){
|
||||
//force building the new docker if that parameter is true
|
||||
echo "Building image: ${image_name}"
|
||||
retimage = docker.build("${image_name}", dockerArgs)
|
||||
withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
|
||||
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
|
||||
retimage.push()
|
||||
}
|
||||
sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi'
|
||||
@@ -162,7 +162,7 @@ def buildDocker(install_prefix){
|
||||
catch(Exception ex){
|
||||
echo "Unable to locate image: ${image_name}. Building image now"
|
||||
retimage = docker.build("${image_name}", dockerArgs + ' .')
|
||||
withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
|
||||
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
|
||||
retimage.push()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ck/host/types.hpp"
|
||||
#include "ck/host/operation/gemm.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// defines all values need for an instance of fwd conv
|
||||
struct Operation_Xdl_CShuffle
|
||||
{
|
||||
// returns a vector of instances, only given fusion operators: will use default problem spec
|
||||
static std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
CreateOperations(const std::string& prologue, const std::string& epilogue);
|
||||
// returns a vector of instances, given a problem spec and fusion operators
|
||||
static std::vector<Operation_Xdl_CShuffle>
|
||||
CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue);
|
||||
TensorDesc A{};
|
||||
TensorDesc B{};
|
||||
TensorDesc B1{};
|
||||
TensorDesc C{};
|
||||
DataType acc = DataType::Float;
|
||||
DataType cs_type = DataType::Half;
|
||||
std::string a_elem_op = PassThrough;
|
||||
std::string b_elem_op = PassThrough;
|
||||
std::string b1_elem_op = PassThrough;
|
||||
std::string c_elem_op = PassThrough;
|
||||
std::string acc_elem_op = Scale;
|
||||
std::string prologue = "";
|
||||
std::string epilogue = "";
|
||||
std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default";
|
||||
// tuning parameters
|
||||
operation::TileDescGemmGemm tile_desc{};
|
||||
operation::BlockTransferDesc a_block_transfer{};
|
||||
operation::BlockTransferDesc b0_block_transfer{};
|
||||
operation::BlockTransferDesc b1_block_transfer{};
|
||||
operation::CShuffleDesc cshuffle{};
|
||||
operation::CBlockTransferDesc c_block_transfer{};
|
||||
|
||||
bool mask_out_upper_triangle = false;
|
||||
|
||||
// functions to update fusion operators if provided
|
||||
void update_prologue(const std::string& prologue);
|
||||
void update_epilogue(const std::string& epilogue);
|
||||
/**constexpr**/ bool
|
||||
IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_);
|
||||
// returns a templated instance
|
||||
Solution ToSolution() const;
|
||||
};
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "ck/host/types.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// defines the problem specification for a GEMM operation
|
||||
struct Problem
|
||||
{
|
||||
std::size_t M = 0;
|
||||
std::size_t N = 0;
|
||||
std::size_t K = 0;
|
||||
std::size_t O = 0;
|
||||
bool TransA = false;
|
||||
bool TransB = false;
|
||||
bool TransB1 = false;
|
||||
bool TransC = false;
|
||||
DataType ADataType = DataType::Half;
|
||||
DataType BDataType = DataType::Half;
|
||||
DataType B1DataType = DataType::Half;
|
||||
DataType CDataType = DataType::Half;
|
||||
std::string AElementOp = PassThrough;
|
||||
std::string BElementOp = PassThrough;
|
||||
std::string B1ElementOp = PassThrough;
|
||||
std::string CElementOp = PassThrough;
|
||||
std::string AccElementOp = Scale;
|
||||
|
||||
// returns the correct device op file for the operation
|
||||
std::string GetIncludeHeader() const;
|
||||
|
||||
// returns a list of instances based on the problem spec and provided fusion operations
|
||||
std::vector<Solution> GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const;
|
||||
};
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle
|
||||
operation::BlockTransferDesc b_block_transfer{};
|
||||
operation::CShuffleDesc cshuffle{};
|
||||
operation::CBlockTransferDesc c_block_transfer{};
|
||||
LoopScheduler loop_scheduler{};
|
||||
PipelineVersion pipeline_version{};
|
||||
|
||||
// functions to update fusion operators if provided
|
||||
void update_prologue(const std::string& prologue);
|
||||
|
||||
@@ -23,6 +23,26 @@ struct TileDesc
|
||||
int n_Xdl_per_wave = 0;
|
||||
int num_gemmk_prefetch_stage = 0;
|
||||
};
|
||||
|
||||
struct TileDescGemmGemm
|
||||
{
|
||||
int block_size = 0;
|
||||
int gemm01_m_per_block = 0;
|
||||
int gemm0_n_per_block = 0;
|
||||
int gemm0_k_per_block = 0;
|
||||
int gemm1_n_per_block = 0;
|
||||
int gemm1_k_per_block = 0;
|
||||
int ak1 = 0;
|
||||
int bk1 = 0;
|
||||
int b1k1 = 0;
|
||||
int m_per_XDL = 0;
|
||||
int n_per_XDL = 0;
|
||||
int gemm0_m_Xdl_per_wave = 0;
|
||||
int gemm0_n_Xdl_per_wave = 0;
|
||||
int gemm1_n_Xdl_per_wave = 0;
|
||||
int num_gemmk_prefetch_stage = 0;
|
||||
};
|
||||
|
||||
struct BlockTransferDesc
|
||||
{
|
||||
std::string thread_cluster_length = "";
|
||||
|
||||
@@ -66,6 +66,20 @@ enum class GemmType
|
||||
};
|
||||
std::string ToString(GemmType gt);
|
||||
|
||||
enum class LoopScheduler
|
||||
{
|
||||
Default,
|
||||
Interwave,
|
||||
};
|
||||
std::string ToString(LoopScheduler ls);
|
||||
|
||||
enum class PipelineVersion
|
||||
{
|
||||
v1,
|
||||
v2
|
||||
};
|
||||
std::string ToString(PipelineVersion pv);
|
||||
|
||||
struct TensorDesc
|
||||
{
|
||||
DataType element;
|
||||
@@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...});
|
||||
|
||||
constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough";
|
||||
constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear";
|
||||
constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale";
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
|
||||
38
codegen/src/device_batched_gemm_softmax_gemm.cpp
Normal file
38
codegen/src/device_batched_gemm_softmax_gemm.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <algorithm>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// return the relevant device op file based on the operation
|
||||
std::string Problem::GetIncludeHeader() const
|
||||
{
|
||||
return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp";
|
||||
}
|
||||
|
||||
// returns templated instances when provided with a problem specification
|
||||
std::vector<Solution> Problem::GetSolutions(const std::string& arch,
|
||||
const std::string& prologue,
|
||||
const std::string& epilogue) const
|
||||
{
|
||||
if(get_xdlop_archs().count(arch) == 0)
|
||||
return {};
|
||||
auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations(
|
||||
*this, prologue, epilogue); // obtains vector of instances
|
||||
std::vector<Solution> result;
|
||||
std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) {
|
||||
return op.ToSolution(); // template instance with correct values
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,408 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
#include "ck/host/utils.hpp"
|
||||
#include <cassert>
|
||||
|
||||
namespace ck {
|
||||
namespace host {
|
||||
namespace device_batched_gemm_softmax_gemm {
|
||||
|
||||
// calculate appropriate Gemm Specification based on input tensor dimensions
|
||||
std::string GetGemmSpec(const std::size_t m,
|
||||
const std::size_t n,
|
||||
const std::size_t k,
|
||||
const std::size_t n1,
|
||||
const std::size_t m_per_block,
|
||||
const std::size_t n_per_block,
|
||||
const std::size_t k_per_block,
|
||||
const std::size_t n1_per_block)
|
||||
{
|
||||
std::string spec = "";
|
||||
if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0)
|
||||
spec += "M";
|
||||
if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0)
|
||||
spec += "N";
|
||||
if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0)
|
||||
spec += "K";
|
||||
if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0)
|
||||
spec += "O";
|
||||
if(spec == "")
|
||||
return "ck::tensor_operation::device::GemmSpecialization::Default";
|
||||
|
||||
return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding";
|
||||
}
|
||||
|
||||
// function to update prologue/epilogue with user provided operation
|
||||
void Operation_Xdl_CShuffle::update_prologue(const std::string& pro)
|
||||
{
|
||||
if(!prologue.empty())
|
||||
{
|
||||
this->prologue = pro;
|
||||
}
|
||||
else
|
||||
{
|
||||
this->prologue = "";
|
||||
}
|
||||
}
|
||||
|
||||
void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
|
||||
{
|
||||
if(!epilogue.empty())
|
||||
{
|
||||
this->epilogue = epi;
|
||||
}
|
||||
else
|
||||
{
|
||||
this->epilogue = "";
|
||||
}
|
||||
}
|
||||
|
||||
// accounts for all possible combinations of Row/Col major
|
||||
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
|
||||
|
||||
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
|
||||
// instances
|
||||
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
const Problem& prob, const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
std::vector<Operation_Xdl_CShuffle> result;
|
||||
|
||||
std::vector<operation::TileDescGemmGemm> tile_descriptions = {
|
||||
// clang-format off
|
||||
// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK|
|
||||
// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch|
|
||||
// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage|
|
||||
// | | | | | | | | | | | Wave| Wave| Wave| |
|
||||
{ 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
{ 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1},
|
||||
{ 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1},
|
||||
// Padded fallback kernel
|
||||
{ 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1},
|
||||
{ 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1},
|
||||
// Irregular k
|
||||
{ 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1},
|
||||
{ 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1},
|
||||
{ 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1},
|
||||
{ 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1},
|
||||
{ 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1},
|
||||
{ 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
const std::vector<operation::BlockTransferDesc> a_block_descriptions = {
|
||||
// clang-format off
|
||||
// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds|
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
|
||||
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
// Padded fallback kernel
|
||||
{ S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true},
|
||||
// Irregular k
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
{ S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
const std::vector<operation::BlockTransferDesc> b1_block_descriptions = {
|
||||
// clang-format off
|
||||
// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds|
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|
|
||||
// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
// Padded fallback kernel
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
// Irregular k
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
{ S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::CShuffleDesc> cshuffle_descriptions = {
|
||||
// clang-format off
|
||||
// CShuffle| CShuffle|
|
||||
// MXdlPerWave| NXdlPerWave|
|
||||
// PerShuffle| PerShuffle|
|
||||
// | |
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 8},
|
||||
{ 1, 4},
|
||||
{ 1, 8},
|
||||
{ 1, 4},
|
||||
// Padded fallback kernel
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
// Irregular k
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
{ 1, 2},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
std::vector<operation::CBlockTransferDesc> c_block_descriptions = {
|
||||
// clang-format off
|
||||
// CBlockTransferClusterLengths| CBlockTransfer
|
||||
// _MBlock_MWaveMPerXdl| ScalarPerVector
|
||||
// _NBlock_NWaveNPerXdl| _NWaveNPerXdl
|
||||
// |
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 16, 1,16>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// Padded fallback kernel
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// Irregular k
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
assert(tile_descriptions.size() == a_block_descriptions.size());
|
||||
assert(tile_descriptions.size() == b1_block_descriptions.size());
|
||||
assert(tile_descriptions.size() == cshuffle_descriptions.size());
|
||||
assert(tile_descriptions.size() == c_block_descriptions.size());
|
||||
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
{
|
||||
Operation_Xdl_CShuffle x;
|
||||
x.tile_desc = tile_descriptions[i];
|
||||
x.a_block_transfer = a_block_descriptions[i];
|
||||
x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a
|
||||
x.b1_block_transfer = b1_block_descriptions[i];
|
||||
x.cshuffle = cshuffle_descriptions[i];
|
||||
x.c_block_transfer = c_block_descriptions[i];
|
||||
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
|
||||
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
|
||||
x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)};
|
||||
x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)};
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.b1_elem_op = prob.B1ElementOp;
|
||||
x.c_elem_op = prob.CElementOp;
|
||||
x.acc_elem_op = prob.AccElementOp;
|
||||
x.gemm_specialization = GetGemmSpec(prob.M,
|
||||
prob.N,
|
||||
prob.K,
|
||||
prob.O,
|
||||
x.tile_desc.gemm01_m_per_block,
|
||||
x.tile_desc.gemm0_n_per_block,
|
||||
x.tile_desc.gemm0_k_per_block,
|
||||
x.tile_desc.gemm1_n_per_block);
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
x.mask_out_upper_triangle = true;
|
||||
result.push_back(x);
|
||||
|
||||
x.mask_out_upper_triangle = false;
|
||||
result.push_back(x);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// set up instances when not provided with a problem specification, use default operation values and
|
||||
// all possible layout combinations
|
||||
std::vector<std::vector<Operation_Xdl_CShuffle>>
|
||||
Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue)
|
||||
{
|
||||
Problem prob;
|
||||
prob.TransA = false;
|
||||
prob.TransB = true;
|
||||
prob.TransB1 = false;
|
||||
prob.TransC = false;
|
||||
|
||||
return {CreateOperations(prob, prologue, epilogue)};
|
||||
}
|
||||
|
||||
static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate =
|
||||
"ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, "
|
||||
"${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, "
|
||||
"${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, "
|
||||
"${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, "
|
||||
"${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, "
|
||||
"${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, "
|
||||
"${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, "
|
||||
"${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, "
|
||||
"${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, "
|
||||
"${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, "
|
||||
"${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, "
|
||||
"${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, "
|
||||
"${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, "
|
||||
"${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, "
|
||||
"${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, "
|
||||
"${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, "
|
||||
"${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, "
|
||||
"${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, "
|
||||
"${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, "
|
||||
"${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
|
||||
"${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, "
|
||||
"${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>";
|
||||
|
||||
// use hardcoded instances from vector of operations to substitute values into instance template
|
||||
Solution Operation_Xdl_CShuffle::ToSolution() const
|
||||
{
|
||||
std::unordered_map<std::string, std::string> values = {
|
||||
{"name",
|
||||
std::to_string(this->tile_desc.block_size) + "_" +
|
||||
std::to_string(this->tile_desc.gemm01_m_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_n_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_k_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm1_n_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.gemm1_k_per_block) + "_" +
|
||||
std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" +
|
||||
std::to_string(this->tile_desc.b1k1) + "_" +
|
||||
std::to_string(this->tile_desc.m_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.n_per_XDL) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" +
|
||||
std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" +
|
||||
std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
|
||||
{"LayoutA", ToString(this->A.layout)},
|
||||
{"LayoutB0", ToString(this->B.layout)},
|
||||
{"LayoutB1", ToString(this->B1.layout)},
|
||||
{"LayoutC", ToString(this->C.layout)},
|
||||
{"ADataType", ToString(this->A.element)},
|
||||
{"B0DataType", ToString(this->B.element)},
|
||||
{"B1DataType", ToString(this->B1.element)},
|
||||
{"CDataType", ToString(this->C.element)},
|
||||
{"AccDataType", ToString(this->acc)},
|
||||
{"CShuffleDataType", ToString(this->cs_type)},
|
||||
{"AElementwiseOperation", this->a_elem_op},
|
||||
{"B0ElementwiseOperation", this->b_elem_op},
|
||||
{"Acc0ElementwiseOperation", this->acc_elem_op},
|
||||
{"B1ElementwiseOperation", this->b1_elem_op},
|
||||
{"CElementwiseOperation", this->c_elem_op},
|
||||
{"GemmSpecialization", this->gemm_specialization},
|
||||
{"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)},
|
||||
{"BlockSize", std::to_string(this->tile_desc.block_size)},
|
||||
{"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)},
|
||||
{"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)},
|
||||
{"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)},
|
||||
{"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)},
|
||||
{"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)},
|
||||
{"AK1", std::to_string(this->tile_desc.ak1)},
|
||||
{"BK1", std::to_string(this->tile_desc.bk1)},
|
||||
{"B1K1", std::to_string(this->tile_desc.b1k1)},
|
||||
{"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)},
|
||||
{"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)},
|
||||
{"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)},
|
||||
{"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)},
|
||||
{"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)},
|
||||
{"ABlockTransferThreadClusterLengths_AK0_M_AK1",
|
||||
this->a_block_transfer.thread_cluster_length},
|
||||
{"ABlockTransferThreadClusterArrangeOrder",
|
||||
this->a_block_transfer.thread_cluster_arrange_order},
|
||||
{"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order},
|
||||
{"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)},
|
||||
{"ABlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->a_block_transfer.src_scalar_per_vector)},
|
||||
{"ABlockTransferDstScalarPerVector_AK1",
|
||||
std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)},
|
||||
{"B0BlockTransferThreadClusterLengths_BK0_N_BK1",
|
||||
this->b0_block_transfer.thread_cluster_length},
|
||||
{"B0BlockTransferThreadClusterArrangeOrder",
|
||||
this->b0_block_transfer.thread_cluster_arrange_order},
|
||||
{"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order},
|
||||
{"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)},
|
||||
{"B0BlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->b0_block_transfer.src_scalar_per_vector)},
|
||||
{"B0BlockTransferDstScalarPerVector_BK1",
|
||||
std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)},
|
||||
{"B1BlockTransferThreadClusterLengths_BK0_N_BK1",
|
||||
this->b1_block_transfer.thread_cluster_length},
|
||||
{"B1BlockTransferThreadClusterArrangeOrder",
|
||||
this->b1_block_transfer.thread_cluster_arrange_order},
|
||||
{"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order},
|
||||
{"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)},
|
||||
{"B1BlockTransferSrcScalarPerVector",
|
||||
std::to_string(this->b1_block_transfer.src_scalar_per_vector)},
|
||||
{"B1BlockTransferDstScalarPerVector_BK1",
|
||||
std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)},
|
||||
{"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)},
|
||||
{"CShuffleMXdlPerWavePerShuffle",
|
||||
std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)},
|
||||
{"CShuffleNXdlPerWavePerShuffle",
|
||||
std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)},
|
||||
{"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl",
|
||||
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
|
||||
{"CBlockTransferScalarPerVector_NWaveNPerXdl",
|
||||
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
|
||||
{"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)},
|
||||
};
|
||||
|
||||
return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values),
|
||||
std::move(values)};
|
||||
}
|
||||
|
||||
} // namespace device_batched_gemm_softmax_gemm
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
@@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi)
|
||||
// accounts for all possible combinations of Row/Col major
|
||||
static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; }
|
||||
|
||||
// clang-format off
|
||||
// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1,
|
||||
|
||||
// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1>
|
||||
// clang-format on
|
||||
|
||||
// Hard-code tuning parameters in modularized fashion, string them together into a vector of
|
||||
// instances
|
||||
std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
@@ -83,6 +89,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1},
|
||||
{ 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1},
|
||||
{ 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1},
|
||||
// Irregular tile
|
||||
{ 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -100,6 +108,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -109,15 +119,17 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM|
|
||||
// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| |
|
||||
// | | | | | | |
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
// clang-format on
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
{S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
};
|
||||
|
||||
std::vector<operation::BlockTransferDesc> b_block_descriptions_rowmajor = {
|
||||
@@ -134,6 +146,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
{ S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -151,6 +165,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
{ S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1},
|
||||
// Irregular tile
|
||||
{ S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -167,6 +183,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
{ 1, 1},
|
||||
// clang-format on
|
||||
};
|
||||
@@ -185,6 +202,8 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
{ S<1, 16, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
{ S<1, 32, 1, 8>, 8},
|
||||
// Irregular tile
|
||||
{ S<1, 16, 1, 4>, 1},
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
@@ -199,33 +218,44 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
|
||||
assert(tile_descriptions.size() == cshuffle_descriptions.size());
|
||||
assert(tile_descriptions.size() == c_block_descriptions.size());
|
||||
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
const std::vector<std::tuple<LoopScheduler, PipelineVersion>> scheduler_pipeline_descriptions =
|
||||
{
|
||||
{LoopScheduler::Default, PipelineVersion::v1},
|
||||
{LoopScheduler::Interwave, PipelineVersion::v1},
|
||||
{LoopScheduler::Default, PipelineVersion::v2},
|
||||
};
|
||||
for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions)
|
||||
{
|
||||
Operation_Xdl_CShuffle x;
|
||||
x.tile_desc = tile_descriptions[i];
|
||||
x.a_block_transfer = a_block_descriptions[i];
|
||||
x.b_block_transfer = b_block_descriptions[i];
|
||||
x.cshuffle = cshuffle_descriptions[i];
|
||||
x.c_block_transfer = c_block_descriptions[i];
|
||||
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
|
||||
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
|
||||
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
|
||||
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
|
||||
return TensorDesc{dt, ToLayout(trans)};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.gemm_specialization = GetGemmSpec(prob.M,
|
||||
prob.N,
|
||||
prob.K,
|
||||
x.tile_desc.m_per_block,
|
||||
x.tile_desc.n_per_block,
|
||||
x.tile_desc.k_per_block);
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
// Put all values together into a single operation > store into the result vector
|
||||
for(std::size_t i = 0; i < tile_descriptions.size(); i++)
|
||||
{
|
||||
Operation_Xdl_CShuffle x;
|
||||
x.tile_desc = tile_descriptions[i];
|
||||
x.a_block_transfer = a_block_descriptions[i];
|
||||
x.b_block_transfer = b_block_descriptions[i];
|
||||
x.cshuffle = cshuffle_descriptions[i];
|
||||
x.c_block_transfer = c_block_descriptions[i];
|
||||
x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)};
|
||||
x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)};
|
||||
x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)};
|
||||
x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) {
|
||||
return TensorDesc{dt, ToLayout(trans)};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.gemm_specialization = GetGemmSpec(prob.M,
|
||||
prob.N,
|
||||
prob.K,
|
||||
x.tile_desc.m_per_block,
|
||||
x.tile_desc.n_per_block,
|
||||
x.tile_desc.k_per_block);
|
||||
x.loop_scheduler = loop_scheduler;
|
||||
x.pipeline_version = pipeline_version;
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate =
|
||||
"${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, "
|
||||
"${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, "
|
||||
"${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, "
|
||||
"${CDEBlockTransferScalarPerVector_NPerBlock}>";
|
||||
"${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>";
|
||||
|
||||
// use hardcoded instances from vector of operations to substitute values into instance template
|
||||
Solution Operation_Xdl_CShuffle::ToSolution() const
|
||||
@@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
|
||||
this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl},
|
||||
{"CDEBlockTransferScalarPerVector_NPerBlock",
|
||||
std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)},
|
||||
{"LoopScheduler", ToString(this->loop_scheduler)},
|
||||
{"PipelineVersion", ToString(this->pipeline_version)},
|
||||
};
|
||||
|
||||
return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values),
|
||||
|
||||
@@ -59,6 +59,26 @@ std::string ToString(GemmType gt)
|
||||
throw std::runtime_error("Incorrect gemm type");
|
||||
}
|
||||
|
||||
std::string ToString(LoopScheduler ls)
|
||||
{
|
||||
switch(ls)
|
||||
{
|
||||
case LoopScheduler::Default: return "ck::LoopScheduler::Default";
|
||||
case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave";
|
||||
}
|
||||
throw std::runtime_error("Incorrect LoopScheduler type");
|
||||
}
|
||||
|
||||
std::string ToString(PipelineVersion pv)
|
||||
{
|
||||
switch(pv)
|
||||
{
|
||||
case PipelineVersion::v1: return "ck::PipelineVersion::v1";
|
||||
case PipelineVersion::v2: return "ck::PipelineVersion::v2";
|
||||
}
|
||||
throw std::runtime_error("Incorrect PipelineVersion type");
|
||||
}
|
||||
|
||||
std::string SequenceStr(const std::vector<int>& v)
|
||||
{
|
||||
return "ck::Sequence<" +
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
|
||||
@@ -506,6 +506,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
if receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'bias']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
@@ -801,4 +809,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
|
||||
@@ -487,13 +487,20 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if receipt == 2:
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
if receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -103,7 +103,8 @@ if __name__ == "__main__":
|
||||
required=False,
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
|
||||
" 1: generate more instance to cover all hdim\n" + \
|
||||
" 2: Only generate instance for Flash attention integration"
|
||||
" 2: Only generate instance for Flash attention integration\n" + \
|
||||
" 4: Only generate instance for PyTorch integration"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -82,8 +82,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
|
||||
typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
|
||||
ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time =
|
||||
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name
|
||||
<< " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
|
||||
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
|
||||
(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
@@ -171,7 +173,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
|
||||
(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
@@ -229,7 +231,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
.insert("local_eid",
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
"please make sure eid is in ascending order!")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
|
||||
@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool local_expert_masking = args.get_str("local_eid") != "-1";
|
||||
auto local_expert_masking_host = [&]() {
|
||||
if(local_expert_masking)
|
||||
{
|
||||
auto local_eid = args.get_int_vec("local_eid");
|
||||
// std::vector<int> v_ {num_experts, 0};
|
||||
ck_tile::HostTensor<IndexType> v_{{num_experts}};
|
||||
v_.SetZero();
|
||||
for(auto eid : local_eid)
|
||||
{
|
||||
if(eid >= num_experts)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"local_eid larger than number of expert, please check");
|
||||
}
|
||||
v_.mData[eid] = 1;
|
||||
}
|
||||
return v_;
|
||||
}
|
||||
else
|
||||
// return std::vector<int>{};
|
||||
return ck_tile::HostTensor<IndexType>{{1}};
|
||||
}();
|
||||
|
||||
// tokens already considered batch size
|
||||
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
|
||||
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
|
||||
@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem local_expert_masking_dev(
|
||||
local_expert_masking_host.get_element_space_size_in_bytes());
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
if(local_expert_masking)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec};
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
ms);
|
||||
topk);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
|
||||
}
|
||||
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
else
|
||||
printf("ms:%f, ", ms);
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int32_t ref_total_tokens_post_pad = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
|
||||
weights_host,
|
||||
local_expert_masking_host,
|
||||
sorted_ids_ref,
|
||||
sorted_weights_ref,
|
||||
sorted_expert_ids_ref,
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size);
|
||||
unit_size,
|
||||
local_expert_masking);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
sorted_id_cnt_host.mData[0]);
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
printf("valid:%s", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
if(!rtn)
|
||||
printf(", (%d)", seed);
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr ck_tile::index_t expert_tile = expert_tile_; \
|
||||
@@ -17,6 +23,67 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
if(is_local_expert_masking) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
if(a.num_experts <= 8) \
|
||||
{ \
|
||||
@@ -38,11 +105,13 @@
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
|
||||
}
|
||||
#endif
|
||||
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
(void)c_;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(r_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
|
||||
@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
$EXE -t=11 -e=256 -k=5
|
||||
$EXE -t=64 -e=455 -k=8
|
||||
$EXE -t=777 -e=802 -k=99
|
||||
$EXE -t=4097 -e=906 -k=51
|
||||
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
|
||||
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
|
||||
@@ -42,7 +42,7 @@ summary of the key design of this fused-moe operator:
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
|
||||
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr ck_tile::index_t expert_tile = expert_tile_; \
|
||||
@@ -17,6 +23,24 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#else
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
using ms_problem = \
|
||||
ck_tile::MoeSortingProblemEx<index_t, ms_weight_type, sub_token_tile, sub_token_onshot>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#endif
|
||||
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
if(a.num_experts <= 8) \
|
||||
{ \
|
||||
@@ -38,11 +62,13 @@
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
|
||||
}
|
||||
#endif
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
#if !MOE_SORTING_USE_EX_KERNEL
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
@@ -83,6 +109,54 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
#else
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
|
||||
auto sub_token_ = r_ - 2;
|
||||
r_ = (r_ - 2) / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
(void)c_;
|
||||
if(is_sub_token_onshot)
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, true);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, true);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, true);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(r_ % 8 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(8, false);
|
||||
}
|
||||
else if(r_ % 4 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(4, false);
|
||||
}
|
||||
else if(r_ % 2 == 0)
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(2, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
MOE_SORTING_DISPATCH_(1, false);
|
||||
}
|
||||
}
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -79,8 +79,11 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@@ -212,7 +212,7 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
|
||||
@@ -118,7 +118,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
|
||||
@@ -202,7 +202,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -610,6 +610,96 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr bool
|
||||
IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_)
|
||||
{
|
||||
// check vector load/store
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
// check vector load of A
|
||||
if constexpr(is_same_v<ALayout, Row>)
|
||||
{
|
||||
if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col>)
|
||||
{
|
||||
if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B
|
||||
if constexpr(is_same_v<BLayout, Row>)
|
||||
{
|
||||
if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<BLayout, Col>)
|
||||
{
|
||||
if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of B1
|
||||
if constexpr(is_same_v<B1Layout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<B1Layout, Col>)
|
||||
{
|
||||
if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check vector load of C
|
||||
if constexpr(is_same_v<CLayout, Row>)
|
||||
{
|
||||
if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<CLayout, Col>)
|
||||
{
|
||||
if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
@@ -624,29 +714,12 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
|
||||
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
|
||||
const auto b_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
|
||||
const auto b1_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
|
||||
const auto c_extent_lowest =
|
||||
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
arg.block_2_ctile_map_) and
|
||||
IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
@@ -764,6 +837,268 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
struct Descriptor
|
||||
{
|
||||
template <class AGridDescriptor>
|
||||
static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc)
|
||||
{
|
||||
const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc);
|
||||
|
||||
const auto M = a_grid_desc_m_k.GetLength(I0);
|
||||
const auto K = a_grid_desc_m_k.GetLength(I1);
|
||||
|
||||
const auto AK0 = K / AK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_grid_desc_m_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class BGridDescriptor>
|
||||
static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc)
|
||||
{
|
||||
const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc);
|
||||
|
||||
const auto N = b_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto BK0 = K / BK1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class B1GridDescriptor>
|
||||
static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc)
|
||||
{
|
||||
const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc);
|
||||
|
||||
const auto N = b1_grid_desc_n_k.GetLength(I0);
|
||||
const auto K = b1_grid_desc_n_k.GetLength(I1);
|
||||
|
||||
const auto B1K0 = K / B1K1;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b1_grid_desc_n_k,
|
||||
make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
|
||||
make_pass_through_transform(N)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
template <class CGridDescriptor>
|
||||
static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc)
|
||||
{
|
||||
return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc);
|
||||
}
|
||||
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(ADesc{}))>;
|
||||
using BGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(BDesc{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(MakeB1GridDescriptor_BK0_N_BK1(B1Desc{}))>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(MakeCGridDescriptor_M_N(CDesc{}))>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
B1GridDesc_BK0_N_BK1,
|
||||
CGridDesc_M_N,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
B1K1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
BBlockLdsExtraN,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
C0MatrixMask c0_matrix_mask;
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
bool has_main_k_block_loop = true;
|
||||
bool is_valid = false;
|
||||
|
||||
constexpr Descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)},
|
||||
b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)},
|
||||
b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)},
|
||||
c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)},
|
||||
block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)},
|
||||
c_grid_descriptor_mblock_mperblock_nblock_nperblock{
|
||||
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n)},
|
||||
has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
|
||||
c0_matrix_mask{c.GetLength(I1)},
|
||||
a_element_op{a_element_op_},
|
||||
b_element_op{b_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_m_n,
|
||||
block_2_ctile_map) and
|
||||
IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1),
|
||||
b_grid_desc_bk0_n_bk1.GetLength(I1),
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I0) *
|
||||
a_grid_desc_ak0_m_ak1.GetLength(I2),
|
||||
b1_grid_desc_bk0_n_bk1.GetLength(I1))}
|
||||
{
|
||||
}
|
||||
|
||||
constexpr bool IsValid() const { return is_valid; }
|
||||
};
|
||||
|
||||
template <class ADesc, class BDesc, class B1Desc, class CDesc>
|
||||
static constexpr auto
|
||||
make_descriptor(ADesc a,
|
||||
BDesc b,
|
||||
B1Desc b1,
|
||||
CDesc c,
|
||||
AElementwiseOperation a_element_op = AElementwiseOperation{},
|
||||
BElementwiseOperation b_element_op = BElementwiseOperation{},
|
||||
B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{},
|
||||
CElementwiseOperation c_element_op = CElementwiseOperation{})
|
||||
{
|
||||
return Descriptor<ADesc, BDesc, B1Desc, CDesc>(
|
||||
a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op);
|
||||
}
|
||||
|
||||
template <class Desc>
|
||||
__device__ static void Run(const Desc& desc,
|
||||
const float scale,
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const ADataType* __restrict__ p_b_grid,
|
||||
const ADataType* __restrict__ p_b1_grid,
|
||||
CDataType* __restrict__ p_c_grid)
|
||||
{
|
||||
#ifndef __HIPCC_RTC__
|
||||
assert(desc.is_valid);
|
||||
#endif
|
||||
__shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
AccElementwiseOperation acc_element_op{scale};
|
||||
|
||||
if(desc.has_main_k_block_loop)
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
Desc::GridwiseGemm::template Run<false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_b1_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
desc.a_element_op,
|
||||
desc.b_element_op,
|
||||
acc_element_op,
|
||||
desc.b1_element_op,
|
||||
desc.c_element_op,
|
||||
desc.a_grid_desc_ak0_m_ak1,
|
||||
desc.b_grid_desc_bk0_n_bk1,
|
||||
desc.b1_grid_desc_bk0_n_bk1,
|
||||
desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock,
|
||||
desc.block_2_ctile_map,
|
||||
desc.c0_matrix_mask);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1495,10 +1495,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
// if workspace is not allocated
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
std::cerr << "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
|
||||
@@ -27,12 +27,12 @@
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/null_type.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/host/arg_parser.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
|
||||
122
include/ck_tile/host/concat.hpp
Normal file
122
include/ck_tile/host/concat.hpp
Normal file
@@ -0,0 +1,122 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
struct IsCharArray : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t N>
|
||||
struct IsCharArray<char[N]> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t N>
|
||||
struct IsCharArray<const char[N]> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t N>
|
||||
struct IsCharArray<char (&)[N]> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <std::size_t N>
|
||||
struct IsCharArray<const char (&)[N]> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v<Ts, std::string_view> ||
|
||||
IsCharArray<Ts>::value ||
|
||||
std::is_same_v<Ts, char>)&&...);
|
||||
|
||||
template <typename... Ts>
|
||||
[[nodiscard]] auto concat(const Ts&... xs)
|
||||
-> std::enable_if_t<!AllConvertibleToStringView<Ts...>, std::string>
|
||||
{
|
||||
using ::operator<<;
|
||||
thread_local std::ostringstream oss;
|
||||
oss.str("");
|
||||
|
||||
(oss << ... << xs);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
template <std::size_t N>
|
||||
[[nodiscard]] constexpr inline std::size_t getSize(char (&)[N]) noexcept
|
||||
{
|
||||
return N;
|
||||
}
|
||||
|
||||
template <std::size_t N>
|
||||
[[nodiscard]] constexpr inline std::size_t getSize(const char (&)[N]) noexcept
|
||||
{
|
||||
return N;
|
||||
}
|
||||
|
||||
[[nodiscard]] constexpr inline std::size_t getSize(const char* s) noexcept
|
||||
{
|
||||
const char* end = s;
|
||||
while(*end++ != 0) {}
|
||||
return end - s - 1;
|
||||
}
|
||||
|
||||
[[nodiscard]] constexpr inline std::size_t getSize(const char&) noexcept { return 1; }
|
||||
|
||||
[[nodiscard]] inline std::size_t getSize(const std::string& s) noexcept { return s.size(); }
|
||||
|
||||
[[nodiscard]] constexpr inline std::size_t getSize(const std::string_view& s) noexcept
|
||||
{
|
||||
return s.size();
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
auto concatInto(std::string& result, const Ts&... xs)
|
||||
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, void>
|
||||
{
|
||||
const std::size_t space = (1 + ... + getSize(xs));
|
||||
result.reserve(result.size() + space);
|
||||
((result += xs), ...);
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
[[nodiscard]] auto concat(const Ts&... xs)
|
||||
-> std::enable_if_t<AllConvertibleToStringView<Ts...>, std::string>
|
||||
{
|
||||
std::string result;
|
||||
concatInto(result, xs...);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Function for types convertible to std::string_view
|
||||
template <typename Sep, typename First, typename... Rest>
|
||||
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
|
||||
-> std::enable_if_t<AllConvertibleToStringView<First, Rest...>, std::string>
|
||||
{
|
||||
std::string result;
|
||||
result += first;
|
||||
((result += sep, result += rest), ...);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Function for other types
|
||||
template <typename Sep, typename First, typename... Rest>
|
||||
[[nodiscard]] auto concat(Sep sep, const First& first, const Rest&... rest)
|
||||
-> std::enable_if_t<!AllConvertibleToStringView<First, Rest...>, std::string>
|
||||
{
|
||||
using ::operator<<;
|
||||
thread_local std::ostringstream oss;
|
||||
oss.str("");
|
||||
oss << first;
|
||||
((oss << sep << rest), ...);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -14,12 +14,15 @@ namespace ck_tile {
|
||||
template <typename WeightType, typename IndexType = index_t>
|
||||
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
const HostTensor<WeightType>& weights,
|
||||
const HostTensor<IndexType>& local_expert_mask,
|
||||
HostTensor<IndexType>& p_sorted_token_ids,
|
||||
HostTensor<WeightType>& sorted_weight,
|
||||
HostTensor<IndexType>& sorted_expert_ids,
|
||||
index_t& unit_cnt,
|
||||
const index_t experts,
|
||||
const index_t unit_size)
|
||||
const index_t unit_size,
|
||||
bool local_expert_masking,
|
||||
bool skip_experts_with_zero_token = true)
|
||||
{
|
||||
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
|
||||
const index_t topk = topk_ids.mDesc.get_lengths()[1];
|
||||
@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
#endif
|
||||
std::vector<std::vector<WeightType>> expert_token_weights(
|
||||
experts, std::vector<WeightType>(unit_size, 0));
|
||||
// count number of unit-size slices in this expert
|
||||
std::vector<IndexType> expert_slices(experts, 1);
|
||||
// count the tokens used in this expert
|
||||
std::vector<IndexType> expert_slice_idxs(experts, 0);
|
||||
// TODO: above 2 buffer seems duplicated
|
||||
|
||||
for(index_t t = 0; t < num_token; t++)
|
||||
{
|
||||
@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
IndexType* out_tokens = p_sorted_token_ids.data();
|
||||
WeightType* out_weights = sorted_weight.data();
|
||||
IndexType* out_expert_id = sorted_expert_ids.data();
|
||||
int curr_expert_id = 0;
|
||||
for(index_t e = 0; e < experts; e++)
|
||||
{
|
||||
if(local_expert_masking)
|
||||
{
|
||||
if(local_expert_mask(e) == 0)
|
||||
continue;
|
||||
}
|
||||
if(skip_experts_with_zero_token)
|
||||
{
|
||||
if(expert_slice_idxs[e] == 0)
|
||||
{
|
||||
curr_expert_id++;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
|
||||
out_tokens += expert_slices[e] * unit_size;
|
||||
memcpy(out_weights,
|
||||
@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
|
||||
for(index_t s = 0; s < expert_slices[e]; s++)
|
||||
{
|
||||
out_expert_id[s] = e;
|
||||
out_expert_id[s] = curr_expert_id;
|
||||
unit_cnt++;
|
||||
}
|
||||
out_expert_id += expert_slices[e];
|
||||
curr_expert_id++;
|
||||
}
|
||||
unit_cnt *= unit_size;
|
||||
return;
|
||||
|
||||
@@ -10,3 +10,4 @@
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -9,3 +9,4 @@
|
||||
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -5,3 +5,4 @@
|
||||
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
34
include/ck_tile/ops/common/utils.hpp
Normal file
34
include/ck_tile/ops/common/utils.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct typeToStr;
|
||||
template <> struct typeToStr<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct typeToStr<fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct typeToStr<bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
|
||||
// clang-format on
|
||||
|
||||
template <typename ADataType_, typename BDataType_>
|
||||
std::string gemm_prec_str()
|
||||
{
|
||||
std::string base_str = std::string(typeToStr<ADataType_>::name);
|
||||
if(!std::is_same_v<ADataType_, BDataType_>)
|
||||
{
|
||||
base_str += "_" + std::string(typeToStr<BDataType_>::name);
|
||||
}
|
||||
return base_str;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,3 +6,4 @@
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -8,3 +8,4 @@
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -9,3 +9,4 @@
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -44,3 +44,4 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
|
||||
@@ -14,6 +15,6 @@
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
|
||||
@@ -15,6 +15,10 @@ namespace ck_tile {
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
@@ -28,7 +32,7 @@ namespace ck_tile {
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
@@ -55,6 +59,34 @@ namespace ck_tile {
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
|
||||
// if enabled, the expert with no tokens will be skipped, in stead of padding to at least 1 unit_size(M_a)
|
||||
//
|
||||
// (pack below tensor, skip element marked with `-`)
|
||||
// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
|
||||
// num_tokens_post_padded_ptr : [24]
|
||||
//
|
||||
// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case)
|
||||
// and modify the output expert-ID, because we will only have enbaled expert on specific GPU.
|
||||
// we call expert input to this kernel as "global expert id", output as "local expert id"
|
||||
//
|
||||
// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
|
||||
//
|
||||
// (pack below tensor, skip element marked with `-`)
|
||||
// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note original it was exper-id= 0, 2, 3, 5, but we produce "local expert id")
|
||||
// num_tokens_post_padded_ptr : [20]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2)need sorted_weight_ptr
|
||||
@@ -67,10 +99,80 @@ namespace ck_tile {
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
|
||||
|
||||
|
||||
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int num_tokens_, int num_experts_)
|
||||
{
|
||||
/* num_experts + 1
|
||||
* +--------------------------------------+
|
||||
* | |
|
||||
* | |
|
||||
* | | * -> sub-tokens
|
||||
* | |
|
||||
* | |
|
||||
* +--------------------------------------+
|
||||
* | | 2 -> cumsum buffer
|
||||
* +--------------------------------------+
|
||||
*
|
||||
*/
|
||||
int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
|
||||
int smem_rows = [&](){
|
||||
index_t target_occupancy_ = 2;
|
||||
constexpr index_t total_ = 65536 / sizeof(int);
|
||||
constexpr index_t sub_unroll = 8;
|
||||
constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
|
||||
// at lease 2 lines, one for sub_token unroll, one for cumsum
|
||||
// should be enough
|
||||
if ((total_ / target_occupancy_) < ((cumsum_bufs+sub_unroll) * smem_cols)) {
|
||||
if ((total_ / 1) < ((cumsum_bufs+sub_unroll) * smem_cols))
|
||||
throw std::runtime_error("too many num_experts, can't allocate smem");
|
||||
target_occupancy_ = 1;
|
||||
}
|
||||
int r = total_ / target_occupancy_ / smem_cols;
|
||||
|
||||
// round to sub_unroll multipl
|
||||
int r_for_sub_token = r - cumsum_bufs;
|
||||
r_for_sub_token = min(r_for_sub_token, num_tokens_);
|
||||
r_for_sub_token = (r_for_sub_token + sub_unroll - 1) / sub_unroll * sub_unroll;
|
||||
r_for_sub_token = max(r_for_sub_token, 1);
|
||||
|
||||
if(r_for_sub_token > 1)
|
||||
{
|
||||
int r_unroll_ = r_for_sub_token / sub_unroll;
|
||||
|
||||
|
||||
// round to 1x/2x/4x/8x number of sub_unroll
|
||||
int clz_ = __builtin_clz(r_unroll_); // 0b1:31 0b2:30, 0b3:30, 0b4:29
|
||||
int mask_ = (1 << (31 - clz_)) - 1;
|
||||
|
||||
|
||||
mask_ = mask_ > 0b111 ? 0b111 : mask_; //clamp to 8x at most
|
||||
mask_ = ~mask_;
|
||||
//printf("r_unroll_:%d, clz:%d, mask:%x\n", r_unroll_, clz_, mask_); fflush(stdout);
|
||||
|
||||
r_for_sub_token = (r_unroll_ & mask_) * sub_unroll;
|
||||
}
|
||||
|
||||
// final check
|
||||
if( (r_for_sub_token + cumsum_bufs * smem_cols * target_occupancy_ ) >= total_ ) {
|
||||
throw std::runtime_error("can't run this kernel, request LDS over size");
|
||||
}
|
||||
|
||||
return r_for_sub_token + cumsum_bufs;
|
||||
}();
|
||||
|
||||
// printf("r:%d, c:%d\n", smem_rows, smem_cols);
|
||||
|
||||
return ck_tile::make_tuple(smem_rows, smem_cols);
|
||||
}
|
||||
|
||||
struct MoeSortingHostArgs
|
||||
{
|
||||
const void* p_topk_ids; // [token, topk]
|
||||
const void* p_weights; // [token, topk]
|
||||
|
||||
const void* p_local_expert_mask;
|
||||
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
@@ -101,6 +203,7 @@ struct MoeSortingKernel
|
||||
{
|
||||
const void* p_topk_ids;
|
||||
const void* p_weights;
|
||||
const void* p_local_expert_mask;
|
||||
void* p_sorted_token_ids;
|
||||
void* p_sorted_weights;
|
||||
void* p_sorted_expert_ids;
|
||||
@@ -111,8 +214,11 @@ struct MoeSortingKernel
|
||||
index_t moe_buf_bytes;
|
||||
|
||||
index_t tokens_per_thread;
|
||||
index_t smem_rows;
|
||||
mdiv unit_size_mdiv;
|
||||
mdiv topk_mdiv;
|
||||
mdiv expert_mdiv;
|
||||
// mdiv sub_tokens_mdiv;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
@@ -123,15 +229,25 @@ struct MoeSortingKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)h;
|
||||
return dim3(256);
|
||||
#else
|
||||
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
|
||||
#endif
|
||||
}
|
||||
|
||||
// in byte
|
||||
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
|
||||
{
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
|
||||
return smem_rows * smem_cols * sizeof(int);
|
||||
#else
|
||||
const auto blocks = BlockSize(h);
|
||||
// usually num_experts is power of 2, we pad 1 dword here for the row-size
|
||||
return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
@@ -139,6 +255,7 @@ struct MoeSortingKernel
|
||||
Kargs k;
|
||||
k.p_topk_ids = h.p_topk_ids;
|
||||
k.p_weights = h.p_weights;
|
||||
k.p_local_expert_mask = h.p_local_expert_mask;
|
||||
k.p_sorted_token_ids = h.p_sorted_token_ids;
|
||||
k.p_sorted_weights = h.p_sorted_weights;
|
||||
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
|
||||
@@ -152,10 +269,18 @@ struct MoeSortingKernel
|
||||
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
|
||||
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
|
||||
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
|
||||
k.smem_rows = [&](){
|
||||
auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
|
||||
(void) c_;
|
||||
return r_;
|
||||
}();
|
||||
k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
|
||||
// k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
|
||||
return k;
|
||||
}
|
||||
|
||||
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
|
||||
// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
|
||||
// NOTE: wave_size need at least be 16!! dpp 16 is one row
|
||||
template <typename data_t, int wave_size>
|
||||
__device__ inline void wave_cumsum(data_t& thread_data) const
|
||||
{
|
||||
@@ -196,6 +321,40 @@ struct MoeSortingKernel
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:4
|
||||
}
|
||||
if constexpr(wave_size == 8) {
|
||||
|
||||
// wave-size=8 need one extra shift
|
||||
thread_data =
|
||||
reduce_op(thread_data,
|
||||
__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x118,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl))); // row_shr:8
|
||||
#if 0
|
||||
constexpr int bank_mask_0_7 = 0b1100;
|
||||
auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; };
|
||||
thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_update_dpp(0, /* old value */
|
||||
__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask_0_7,
|
||||
bound_ctrl))// row_newbcast:7
|
||||
);
|
||||
#else
|
||||
data_t xxx =__builtin_bit_cast(data_t,
|
||||
__builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
|
||||
0x157,
|
||||
row_mask,
|
||||
bank_mask,
|
||||
bound_ctrl)); // row_newbcast:7
|
||||
|
||||
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
|
||||
thread_data = thread_data - yyy;
|
||||
#endif
|
||||
|
||||
}
|
||||
if constexpr(wave_size > 8)
|
||||
{
|
||||
thread_data =
|
||||
@@ -224,6 +383,36 @@ struct MoeSortingKernel
|
||||
}
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
template <typename T, typename F, index_t wave_size_ = warpSize>
|
||||
__device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
// constexpr int reduce_stage = 6; // 1<<6=64
|
||||
// clang-format off
|
||||
constexpr int reduce_stage = [](){
|
||||
if constexpr(wave_size_ == 2) return 1;
|
||||
else if constexpr(wave_size_ == 4) return 2;
|
||||
else if constexpr(wave_size_ == 8) return 3;
|
||||
else if constexpr(wave_size_ == 16) return 4;
|
||||
else if constexpr(wave_size_ == 32) return 5;
|
||||
else if constexpr(wave_size_ == 64) return 6;
|
||||
else return 0;
|
||||
}();
|
||||
// clang-format on
|
||||
T v_local = local;
|
||||
#pragma unroll reduce_stage
|
||||
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
|
||||
{
|
||||
int src_lane = __lane_id() ^ (1 << i_stage);
|
||||
int32_t v_remote_tmp =
|
||||
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
|
||||
T v_remote = bit_cast<T>(v_remote_tmp);
|
||||
v_local = reduce_f(v_local, v_remote);
|
||||
}
|
||||
return v_local;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
|
||||
{
|
||||
return row * total_col + col;
|
||||
@@ -257,37 +446,37 @@ struct MoeSortingKernel
|
||||
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
|
||||
|
||||
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
|
||||
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1)
|
||||
index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)
|
||||
|
||||
for(int i = 0; i < num_experts; ++i)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0;
|
||||
tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
|
||||
}
|
||||
|
||||
#pragma unroll Problem_::InternalLoadUnroll
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])];
|
||||
++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#if 1
|
||||
if(tid < num_experts)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
|
||||
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
|
||||
index_t local_c[8];
|
||||
index_t prev_c = 0;
|
||||
// TODO: manually unroll. pragma unroll does not work well when we have dependency
|
||||
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i+= 8)
|
||||
for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
|
||||
{
|
||||
local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)];
|
||||
local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)];
|
||||
local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)];
|
||||
local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)];
|
||||
local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)];
|
||||
local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)];
|
||||
local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)];
|
||||
local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)];
|
||||
local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
|
||||
local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
|
||||
local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
|
||||
local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
|
||||
local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
|
||||
local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
|
||||
local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
|
||||
local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
|
||||
|
||||
local_c[0] += prev_c;
|
||||
local_c[1] += local_c[0];
|
||||
@@ -299,51 +488,57 @@ struct MoeSortingKernel
|
||||
local_c[7] += local_c[6];
|
||||
prev_c = local_c[7];
|
||||
|
||||
tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6];
|
||||
tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
|
||||
tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
|
||||
}
|
||||
}
|
||||
#else
|
||||
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
|
||||
// TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
|
||||
// heuristic
|
||||
{
|
||||
if(tid < num_experts)
|
||||
tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
|
||||
for(int i = 0; i < num_experts; i+=8) {
|
||||
tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
|
||||
for(int i = 0; i < num_experts; i += 8)
|
||||
{
|
||||
index_t local_c[8];
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++) {
|
||||
local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)];
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++)
|
||||
{
|
||||
local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++) {
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++)
|
||||
{
|
||||
wave_cumsum<int, 64>(local_c[j]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++) {
|
||||
tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j];
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 8; j++)
|
||||
{
|
||||
tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
if constexpr (Problem::ExpertTile == 0) {
|
||||
if constexpr(Problem::ExpertTile == 0)
|
||||
{
|
||||
if(tid == 0)
|
||||
{
|
||||
cumsum[0] = 0;
|
||||
for(int i = 1; i <= num_experts; ++i)
|
||||
{
|
||||
auto current_units = [&]() {
|
||||
index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] +
|
||||
unit_size_mdiv.divisor - 1;
|
||||
index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
|
||||
unit_size_mdiv.divisor - 1;
|
||||
index_t y_ = unit_size_mdiv.div(x_);
|
||||
return max(y_, 1) * unit_size_mdiv.divisor;
|
||||
}();
|
||||
@@ -351,20 +546,24 @@ struct MoeSortingKernel
|
||||
}
|
||||
*p_total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
} else {
|
||||
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
|
||||
// for simplicity, not check experts here.
|
||||
int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
|
||||
// expert) for simplicity, not check experts here.
|
||||
int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
|
||||
int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
|
||||
int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
|
||||
int local_cumsum = padded_tokens_per_expert;
|
||||
int local_cumsum = padded_tokens_per_expert;
|
||||
wave_cumsum<int, 64>(local_cumsum);
|
||||
|
||||
if(tid == (num_experts - 1)) {
|
||||
cumsum[0] = 0;
|
||||
if(tid == (num_experts - 1))
|
||||
{
|
||||
cumsum[0] = 0;
|
||||
*p_total_tokens_post_pad = local_cumsum;
|
||||
}
|
||||
if(tid < num_experts) {
|
||||
if(tid < num_experts)
|
||||
{
|
||||
cumsum[tid + 1] = local_cumsum;
|
||||
}
|
||||
}
|
||||
@@ -373,7 +572,7 @@ struct MoeSortingKernel
|
||||
if(tid < num_experts)
|
||||
{
|
||||
int e_start = cumsum[tid];
|
||||
int e_end = cumsum[tid + 1];
|
||||
int e_end = cumsum[tid + 1];
|
||||
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
|
||||
{
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
|
||||
@@ -383,8 +582,8 @@ struct MoeSortingKernel
|
||||
#pragma unroll Problem_::InternalLoadUnroll
|
||||
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
|
||||
{
|
||||
index_t expert_id = topk_id[i];
|
||||
index_t local_cnt = tokens_cnts[calc_index(num_experts+1, tid, expert_id)];
|
||||
index_t expert_id = topk_id[i];
|
||||
index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
|
||||
index_t rank_post_pad = local_cnt + cumsum[expert_id];
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
@@ -393,16 +592,17 @@ struct MoeSortingKernel
|
||||
#else
|
||||
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
|
||||
#endif
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1;
|
||||
p_sorted_weights[rank_post_pad] = weights[i];
|
||||
tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
|
||||
}
|
||||
|
||||
if constexpr (Problem::ExpertTile == 0) {
|
||||
if constexpr(Problem::ExpertTile == 0)
|
||||
{
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
if(tid < num_experts)
|
||||
{
|
||||
index_t expert_offset =
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
|
||||
cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
|
||||
index_t expert_end = cumsum[tid + 1];
|
||||
while(expert_offset < expert_end)
|
||||
{
|
||||
@@ -417,16 +617,19 @@ struct MoeSortingKernel
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
else
|
||||
{
|
||||
const index_t prefill_token = topk_mdiv.div(numel);
|
||||
// TODO: only support expert-tile like 8, 16, 32
|
||||
static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
|
||||
{
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset =
|
||||
cumsum[eid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave;
|
||||
index_t eid = tid / experts_per_wave;
|
||||
index_t expert_offset = cumsum[eid] +
|
||||
tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
|
||||
tid % experts_per_wave;
|
||||
index_t expert_end = cumsum[eid + 1];
|
||||
if(eid < num_experts) {
|
||||
if(eid < num_experts)
|
||||
{
|
||||
while(expert_offset < expert_end)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
@@ -436,10 +639,363 @@ struct MoeSortingKernel
|
||||
p_sorted_token_ids[expert_offset] = prefill_token;
|
||||
#endif
|
||||
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
|
||||
expert_offset+=experts_per_wave;
|
||||
expert_offset += experts_per_wave;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// only support index_t, and single pixel access
|
||||
struct simple_smem_indexer
|
||||
{
|
||||
index_t* smem;
|
||||
index_t row_stride;
|
||||
|
||||
// this is 2D
|
||||
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_, index_t row_stride_)
|
||||
: smem(smem_), row_stride(row_stride_)
|
||||
{
|
||||
}
|
||||
CK_TILE_DEVICE const index_t& operator()(index_t i_row, index_t i_col) const
|
||||
{
|
||||
return smem[i_row * row_stride + i_col];
|
||||
}
|
||||
CK_TILE_DEVICE index_t& operator()(index_t i_row, index_t i_col)
|
||||
{
|
||||
return smem[i_row * row_stride + i_col];
|
||||
}
|
||||
|
||||
// this is 1D or linear
|
||||
CK_TILE_DEVICE simple_smem_indexer(index_t* smem_) : smem(smem_), row_stride(0) {}
|
||||
CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
|
||||
CK_TILE_DEVICE index_t& operator()(index_t idx) { return smem[idx]; }
|
||||
};
|
||||
|
||||
CK_TILE_DEVICE void
|
||||
moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
|
||||
const WeightType* __restrict__ weights,
|
||||
const IndexType* __restrict__ local_expert_mask,
|
||||
index_t* p_sorted_token_ids,
|
||||
WeightType* p_sorted_weights,
|
||||
index_t* p_sorted_expert_ids,
|
||||
index_t* p_total_tokens_post_pad,
|
||||
const index_t num_experts,
|
||||
const index_t tokens,
|
||||
const mdiv unit_size_mdiv,
|
||||
const mdiv topk_mdiv,
|
||||
const mdiv expert_mdiv,
|
||||
const index_t smem_rows,
|
||||
void* smem) const
|
||||
{
|
||||
const index_t tid = static_cast<index_t>(threadIdx.x);
|
||||
const index_t wid = __builtin_amdgcn_readfirstlane(tid / warpSize);
|
||||
const index_t lid = __lane_id();
|
||||
constexpr index_t block_size = 256; // blockDim.x;
|
||||
const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
|
||||
const index_t topk = topk_mdiv.divisor;
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
const index_t smem_cols = num_experts + 1;
|
||||
|
||||
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
|
||||
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
|
||||
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
|
||||
smem_cols};
|
||||
|
||||
// #pragma unroll 8
|
||||
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
|
||||
{
|
||||
uint32_t curr_token_id, curr_expert_id;
|
||||
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
|
||||
smem_tokens(curr_token_id, curr_expert_id) = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
|
||||
{
|
||||
// NOTE: below for loop can't have barrier inside!!
|
||||
for(int i = tid; i < (sub_tokens * topk); i += block_size)
|
||||
{
|
||||
uint32_t curr_token_id, curr_topk_id;
|
||||
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
|
||||
int i_t = i_token + curr_token_id;
|
||||
|
||||
if(i_t < tokens)
|
||||
{
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
|
||||
if constexpr(Problem::SubTokenOneShot)
|
||||
smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
|
||||
else
|
||||
smem_tokens(curr_token_id, eid)++;
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
__syncthreads(); // make sure different i_token iteration not overlap by different wave
|
||||
}
|
||||
|
||||
// counting
|
||||
if(tid == 0)
|
||||
{
|
||||
smem_cumsum(0) = 0;
|
||||
// smem_cumdup(0) = 0;
|
||||
}
|
||||
|
||||
{
|
||||
constexpr int lane_group_sz = 8;
|
||||
int lane_group_id = tid / lane_group_sz;
|
||||
int lane_group_os = tid % lane_group_sz;
|
||||
constexpr int lane_group_nm = block_size / lane_group_sz;
|
||||
|
||||
for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
|
||||
{
|
||||
index_t local_c[Problem::SubTokenTile];
|
||||
index_t cnt = 0;
|
||||
|
||||
for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
|
||||
{
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(int j = 0; j < Problem::SubTokenTile; j++)
|
||||
{
|
||||
local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
|
||||
if constexpr(Problem::SubTokenOneShot)
|
||||
{
|
||||
local_c[j] = local_c[j] != 0 ? 1 : 0;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll Problem::SubTokenTile
|
||||
for(int j = 0; j < Problem::SubTokenTile; j++)
|
||||
{
|
||||
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
|
||||
}
|
||||
}
|
||||
if(lane_group_os == 0)
|
||||
smem_cumsum(i_e + 1) = cnt;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
smem_cumdup(0) = 0;
|
||||
for(int i_e = tid; i_e < num_experts; i_e += block_size)
|
||||
{
|
||||
// reuse this buffer
|
||||
smem_cumdup(i_e + 1) = local_expert_mask[i_e];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
{
|
||||
if(wid == 0)
|
||||
{
|
||||
// NOTE: under this block can never use __syncthreads!
|
||||
int i_e_ = 0;
|
||||
int local_cumsum_ = 0;
|
||||
for(; i_e_ < num_experts; i_e_ += warpSize)
|
||||
{
|
||||
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
int blocks_pers_expert =
|
||||
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
|
||||
|
||||
int pre_cumsum_masking = [&]() {
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
return smem_cumdup(lid == 0 ? i_e_ : 0);
|
||||
else
|
||||
return 0; // not used
|
||||
}();
|
||||
int local_masking = [&]() {
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
return smem_cumdup(i_e_ + lid + 1);
|
||||
else
|
||||
return 0; // not used
|
||||
}();
|
||||
int padded_tokens_per_expert = [&]() {
|
||||
int x_ = [&]() {
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
// if local_cnt is zero, blocks_pers_expert will be zero
|
||||
// this is what we want to achieve
|
||||
return blocks_pers_expert * unit_size_mdiv.divisor;
|
||||
}
|
||||
else
|
||||
{
|
||||
return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
|
||||
}
|
||||
}();
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
return local_masking ? x_ : 0;
|
||||
}
|
||||
else
|
||||
return x_;
|
||||
}();
|
||||
|
||||
local_cumsum_ = padded_tokens_per_expert;
|
||||
local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
|
||||
// cumsum padded in case local cumsum is zero, but
|
||||
// pre_sumsum has value, which will result int
|
||||
// zero local cumsum(but we want at least padded)
|
||||
wave_cumsum<int, warpSize>(local_cumsum_);
|
||||
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
local_masking += pre_cumsum_masking;
|
||||
wave_cumsum<int, warpSize>(local_masking);
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumdup(i_e_ + lid + 1) = local_masking;
|
||||
}
|
||||
|
||||
// NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
|
||||
// for above write however __syncthreads will cause barrier with waves other
|
||||
// than 0(which is not we want)
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
if((lid + i_e_ - warpSize) == (num_experts - 1))
|
||||
{
|
||||
*p_total_tokens_post_pad = local_cumsum_;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
for(int i_e = tid; i_e < num_experts; i_e += block_size)
|
||||
{
|
||||
int e_start = smem_cumsum(i_e);
|
||||
int e_end = smem_cumsum(i_e + 1);
|
||||
|
||||
int expert_id = [&]() {
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
// local expert id from cumsum
|
||||
return smem_cumdup(i_e);
|
||||
}
|
||||
else
|
||||
return i_e;
|
||||
}();
|
||||
|
||||
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end) // skip zero token expert
|
||||
continue;
|
||||
}
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
if(local_expert_mask[i_e] == 0)
|
||||
continue;
|
||||
}
|
||||
|
||||
for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
|
||||
{
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
|
||||
}
|
||||
}
|
||||
smem_cumdup(num_experts) = smem_cumsum(num_experts);
|
||||
|
||||
// fill the p_sorted_token_ids/p_sorted_weights
|
||||
for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
|
||||
{
|
||||
if constexpr(!Problem::SubTokenOneShot)
|
||||
{
|
||||
// clear every time
|
||||
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
|
||||
{
|
||||
uint32_t curr_token_id, curr_expert_id;
|
||||
expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
|
||||
smem_tokens(curr_token_id, curr_expert_id) = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// load again
|
||||
for(int i = tid; i < (sub_tokens * topk); i += block_size)
|
||||
{
|
||||
uint32_t curr_token_id_, curr_topk_id_;
|
||||
topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
|
||||
int curr_token_id = static_cast<int>(curr_token_id_);
|
||||
int curr_topk_id = static_cast<int>(curr_topk_id_);
|
||||
int i_t = i_token + curr_token_id;
|
||||
if(i_t < tokens)
|
||||
{
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
{
|
||||
constexpr int lane_group_sz = 8;
|
||||
int lane_group_id = tid / lane_group_sz;
|
||||
int lane_group_os = tid % lane_group_sz;
|
||||
constexpr int lane_group_nm = block_size / lane_group_sz;
|
||||
for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
|
||||
{
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
if(local_expert_mask[eid] == 0)
|
||||
continue;
|
||||
}
|
||||
int position = smem_cumsum(eid);
|
||||
for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
|
||||
i_sub_token += lane_group_sz)
|
||||
{
|
||||
auto x = smem_tokens(i_sub_token, eid);
|
||||
|
||||
int local_cnt_cache = x != 0 ? 1 : 0;
|
||||
int local_cnt = local_cnt_cache;
|
||||
wave_cumsum<int, lane_group_sz>(local_cnt);
|
||||
if(x != 0)
|
||||
{
|
||||
// now x is topk value
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[position + local_cnt - 1] =
|
||||
MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
|
||||
#else
|
||||
p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
|
||||
#endif
|
||||
p_sorted_weights[position + local_cnt - 1] =
|
||||
weights[(i_token + i_sub_token) * topk + x - 1];
|
||||
}
|
||||
|
||||
int remote_cnt = __builtin_amdgcn_ds_bpermute(
|
||||
(lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
|
||||
|
||||
position += remote_cnt;
|
||||
}
|
||||
smem_cumsum(eid) = position;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// add the skip number
|
||||
for(int eid = tid; eid < num_experts; eid += block_size)
|
||||
{
|
||||
int e_start = smem_cumsum(eid);
|
||||
int e_end = smem_cumdup(eid + 1);
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end) // skip zero token expert
|
||||
continue;
|
||||
}
|
||||
while(e_start < e_end)
|
||||
{
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
|
||||
#else
|
||||
p_sorted_token_ids[e_start] = tokens;
|
||||
#endif
|
||||
p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
|
||||
e_start++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -456,6 +1012,24 @@ struct MoeSortingKernel
|
||||
}
|
||||
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
|
||||
extern __shared__ char smem[];
|
||||
#if MOE_SORTING_USE_EX_KERNEL
|
||||
(void)numel;
|
||||
return moe_align_block_size_kernel_ex(
|
||||
static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
static_cast<const IndexType*>(kargs.p_local_expert_mask),
|
||||
static_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
static_cast<WeightType*>(kargs.p_sorted_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
|
||||
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.num_experts,
|
||||
kargs.tokens,
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
kargs.expert_mdiv,
|
||||
kargs.smem_rows,
|
||||
smem);
|
||||
#else
|
||||
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
|
||||
static_cast<const WeightType*>(kargs.p_weights),
|
||||
static_cast<IndexType*>(kargs.p_sorted_token_ids),
|
||||
@@ -468,6 +1042,7 @@ struct MoeSortingKernel
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
smem);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
52
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
Normal file
52
include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename IndexType_,
|
||||
typename WeightType_,
|
||||
index_t InternalLoadUnroll_,
|
||||
index_t ExpertTile_ = 0>
|
||||
struct MoeSortingProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
static constexpr index_t WarpsPerBlock = 1;
|
||||
static constexpr index_t InternalLoadUnroll =
|
||||
InternalLoadUnroll_; // TODO: need better design(like tile size)
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
};
|
||||
|
||||
template <typename IndexType_,
|
||||
typename WeightType_,
|
||||
index_t SubTokenTile_, // 1,2,4,8, or 0 in the future
|
||||
bool SubTokenOneShot_, // if we only loop over once or not
|
||||
bool LocalExpertMasking_, // used in EP case
|
||||
bool SkipExpertsWithZeroTokens_ = true,
|
||||
index_t ExpertTile_ = 0>
|
||||
struct MoeSortingProblemEx
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
static constexpr index_t WarpsPerBlock = 1;
|
||||
static constexpr index_t SubTokenTile = SubTokenTile_;
|
||||
static constexpr bool SubTokenOneShot = SubTokenOneShot_;
|
||||
static constexpr bool LocalExpertMasking = LocalExpertMasking_;
|
||||
static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_;
|
||||
static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || SubTokenTile == 8);
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,28 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename IndexType_,
|
||||
typename WeightType_,
|
||||
index_t InternalLoadUnroll_,
|
||||
index_t ExpertTile_ = 0>
|
||||
struct MoeSortingProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
static constexpr index_t WarpsPerBlock = 1;
|
||||
static constexpr index_t InternalLoadUnroll =
|
||||
InternalLoadUnroll_; // TODO: need better design(like tile size)
|
||||
static constexpr index_t ExpertTile = ExpertTile_; // TODO: only used in store out
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -46,3 +46,4 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
using BLayout = typename Base::BLayout;
|
||||
using CLayout = typename Base::CLayout;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
|
||||
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
struct BatchedGemmKernelArgs : GemmKernelArgs
|
||||
{
|
||||
index_t batch_stride_A;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -75,6 +76,13 @@ struct GemmKernel
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
|
||||
@@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
}
|
||||
};
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
|
||||
concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
__host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
{
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -81,6 +82,15 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV3", BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -128,6 +129,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrMe",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
using Base::PrefetchStages;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -39,6 +40,18 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -75,8 +88,9 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
|
||||
16;
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
kLdsAlignmentInBytes) *
|
||||
kLdsAlignmentInBytes;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -25,6 +26,13 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV2",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
|
||||
// clang-format on
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -35,9 +36,19 @@ struct GemmPipelineProblemBase
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -19,6 +20,16 @@ struct TileGemmShape
|
||||
static constexpr index_t kM = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
static constexpr index_t kK = BlockTile::at(number<2>{});
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "tile_gemm_shape",
|
||||
concat('x', kM, kN, kK, NumWarps),
|
||||
concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})),
|
||||
concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})));
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -8,3 +8,4 @@
|
||||
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,3 +11,4 @@
|
||||
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -8,3 +8,4 @@
|
||||
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,3 +7,4 @@
|
||||
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -9,3 +9,4 @@
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,3 +11,4 @@
|
||||
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -11,3 +11,4 @@
|
||||
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,3 +7,4 @@
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -7,3 +7,4 @@
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -9,3 +9,4 @@
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
Reference in New Issue
Block a user