mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
try encapsulating the kernel instantiation guts
This commit is contained in:
@@ -5,6 +5,117 @@
|
||||
#include <functional>
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
namespace ck_tile::experimental::builder {
|
||||
template <class AlgorithmMetadata, class InputMetadata>
|
||||
struct UniversalFactory
|
||||
{
|
||||
private:
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<AlgorithmMetadata::M_Tile::value,
|
||||
AlgorithmMetadata::N_Tile::value,
|
||||
AlgorithmMetadata::K_Tile::value>,
|
||||
ck_tile::sequence<AlgorithmMetadata::M_Warp::value,
|
||||
AlgorithmMetadata::N_Warp::value,
|
||||
AlgorithmMetadata::K_Warp::value>,
|
||||
ck_tile::sequence<AlgorithmMetadata::M_Warp_Tile::value,
|
||||
AlgorithmMetadata::N_Warp_Tile::value,
|
||||
AlgorithmMetadata::K_Warp_Tile::value>,
|
||||
AlgorithmMetadata::PermuteA::value,
|
||||
AlgorithmMetadata::PermuteB::value>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
AlgorithmMetadata::TileParitionerGroupNum::value,
|
||||
AlgorithmMetadata::TileParitionerM01::value>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<AlgorithmMetadata::kPadM::value,
|
||||
AlgorithmMetadata::kPadN::value,
|
||||
AlgorithmMetadata::kPadK::value,
|
||||
typename InputMetadata::InputALayout,
|
||||
typename InputMetadata::InputBLayout,
|
||||
typename InputMetadata::InputELayout,
|
||||
AlgorithmMetadata::NumWaveGroups::value>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<AlgorithmMetadata::kPadM::value,
|
||||
AlgorithmMetadata::kPadN::value,
|
||||
AlgorithmMetadata::kPadK::value,
|
||||
AlgorithmMetadata::DoubleSmemBuffer::value,
|
||||
typename InputMetadata::InputALayout,
|
||||
typename InputMetadata::InputBLayout,
|
||||
typename InputMetadata::InputELayout,
|
||||
AlgorithmMetadata::TransposeC::value,
|
||||
AlgorithmMetadata::UseStructuredSparsity::value,
|
||||
AlgorithmMetadata::KPersistent::value,
|
||||
AlgorithmMetadata::NumWaveGroups::value,
|
||||
AlgorithmMetadata::Preshuffle::value>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<typename InputMetadata::InputADataType,
|
||||
typename InputMetadata::InputBDataType,
|
||||
typename InputMetadata::InputAccDataType,
|
||||
GemmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
AlgorithmMetadata::Pipeline::value>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<typename InputMetadata::InputADataType,
|
||||
typename InputMetadata::InputBDataType,
|
||||
typename InputMetadata::InputAccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
AlgorithmMetadata::Scheduler::value,
|
||||
AlgorithmMetadata::HasHotLoop::value,
|
||||
AlgorithmMetadata::TailNum::value>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
AlgorithmMetadata::Pipeline::value>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename InputMetadata::InputADataType,
|
||||
typename InputMetadata::InputBDataType,
|
||||
typename InputMetadata::InputDsDataType,
|
||||
typename InputMetadata::InputAccDataType,
|
||||
typename InputMetadata::InputCDataType,
|
||||
typename InputMetadata::InputDsLayout,
|
||||
typename InputMetadata::InputELayout,
|
||||
typename InputMetadata::InputCDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
AlgorithmMetadata::M_Warp::value,
|
||||
AlgorithmMetadata::N_Warp::value,
|
||||
AlgorithmMetadata::M_Warp_Tile::value,
|
||||
AlgorithmMetadata::N_Warp_Tile::value,
|
||||
AlgorithmMetadata::K_Warp_Tile::value,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
AlgorithmMetadata::MemoryOperation::value,
|
||||
AlgorithmMetadata::NumWaveGroups::value>>;
|
||||
|
||||
public:
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
CK_TILE_HOST static constexpr auto make_kernel(const ck_tile::GemmHostArgs& args)
|
||||
{
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
// NB: do we really need the stream to be launched here?
|
||||
const dim3 grids = AlgorithmMetadata::KPersistent::value
|
||||
? Kernel::MaxOccupancyGridSize(ck_tile::stream_config{})
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
return ck_tile::make_kernel<AlgorithmMetadata::kBlockPerCu::value>(
|
||||
Kernel{}, grids, blocks, 0, kargs);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile::experimental::builder
|
||||
|
||||
struct UniversalInvoker
|
||||
{
|
||||
template <typename GemmConfig,
|
||||
@@ -22,162 +133,165 @@ struct UniversalInvoker
|
||||
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
// const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
// const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain *
|
||||
// GemmConfig::K_Tile; const ck_tile::index_t num_loop =
|
||||
// TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop =
|
||||
// BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num =
|
||||
// BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
// const ck_tile::index_t num_loop = 64;
|
||||
const bool has_hot_loop = true;
|
||||
const ck_tile::TailNumber tail_num = ck_tile::TailNumber::Full;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto kernel_launch_visitor = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
struct Algo
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
// can't do `static constexpr` in local structs
|
||||
using M_Tile =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::M_Tile), GemmConfig::M_Tile>;
|
||||
using N_Tile =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::N_Tile), GemmConfig::N_Tile>;
|
||||
using K_Tile =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::K_Tile), GemmConfig::K_Tile>;
|
||||
using M_Warp =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::M_Warp), GemmConfig::M_Warp>;
|
||||
using N_Warp =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::N_Warp), GemmConfig::N_Warp>;
|
||||
using K_Warp =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::K_Warp), GemmConfig::K_Warp>;
|
||||
using M_Warp_Tile = ck_tile::integral_constant<decltype(GemmConfig::M_Warp_Tile),
|
||||
GemmConfig::M_Warp_Tile>;
|
||||
using N_Warp_Tile = ck_tile::integral_constant<decltype(GemmConfig::N_Warp_Tile),
|
||||
GemmConfig::N_Warp_Tile>;
|
||||
using K_Warp_Tile = ck_tile::integral_constant<decltype(GemmConfig::K_Warp_Tile),
|
||||
GemmConfig::K_Warp_Tile>;
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
using kPadM =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::kPadM), GemmConfig::kPadM>;
|
||||
using kPadN =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::kPadN), GemmConfig::kPadN>;
|
||||
using kPadK =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::kPadK), GemmConfig::kPadK>;
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
using PermuteA = ck_tile::integral_constant<decltype(GemmConfig::PermuteA),
|
||||
GemmConfig::PermuteA>;
|
||||
using PermuteB = ck_tile::integral_constant<decltype(GemmConfig::PermuteB),
|
||||
GemmConfig::PermuteB>;
|
||||
using UseStructuredSparsity =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::UseStructuredSparsity),
|
||||
GemmConfig::UseStructuredSparsity>;
|
||||
using KPersistent = ck_tile::integral_constant<decltype(Persistent), Persistent>;
|
||||
using Preshuffle = ck_tile::integral_constant<decltype(GemmConfig::Preshuffle),
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
using NumWaveGroups =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::NumWaveGroups),
|
||||
GemmConfig::NumWaveGroups>;
|
||||
using DoubleSmemBuffer =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::DoubleSmemBuffer),
|
||||
GemmConfig::DoubleSmemBuffer>;
|
||||
using TransposeC = ck_tile::integral_constant<decltype(GemmConfig::TransposeC),
|
||||
GemmConfig::TransposeC>;
|
||||
|
||||
using HasHotLoop = decltype(has_hot_loop_);
|
||||
using MemoryOperation = decltype(memory_operation_);
|
||||
using TailNum = decltype(tail_number_);
|
||||
|
||||
using Scheduler = ck_tile::integral_constant<decltype(GemmConfig::Scheduler),
|
||||
GemmConfig::Scheduler>;
|
||||
using TileParitionerGroupNum =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::TileParitionerGroupNum),
|
||||
GemmConfig::TileParitionerGroupNum>;
|
||||
using TileParitionerM01 =
|
||||
ck_tile::integral_constant<decltype(GemmConfig::TileParitionerM01),
|
||||
GemmConfig::TileParitionerM01>;
|
||||
using Pipeline = ck_tile::integral_constant<decltype(GemmConfig::Pipeline),
|
||||
GemmConfig::Pipeline>;
|
||||
|
||||
using kBlockPerCu = ck_tile::integral_constant<decltype(GemmConfig::kBlockPerCu),
|
||||
GemmConfig::kBlockPerCu>;
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
struct Inp
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
using InputADataType = ADataType;
|
||||
using InputBDataType = BDataType;
|
||||
using InputDsDataType = DsDataType;
|
||||
using InputCDataType = CDataType;
|
||||
using InputAccDataType = AccDataType;
|
||||
using InputALayout = ALayout;
|
||||
using InputBLayout = BLayout;
|
||||
using InputDsLayout = DsLayout;
|
||||
using InputELayout = ELayout;
|
||||
using InputCDEElementWise = CDEElementWise;
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
// if(s.log_level_ > 0)
|
||||
// {
|
||||
// std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
// << "shape: " << GemmShape::GetName() << '\n'
|
||||
// << "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
// << "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
// << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
// << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
// << "}" << std::endl;
|
||||
// }
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
// std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
// std::function<void()> preprocess;
|
||||
|
||||
rotating_mem_ptr =
|
||||
std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0],
|
||||
kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
// auto clear_gemm_output = [&]() {
|
||||
// if(args.k_batch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(
|
||||
// args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
// };
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
// if(s.flush_cache_)
|
||||
// {
|
||||
// std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
// ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
// args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
// ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
// args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
// auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
// auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
// rotating_mem_ptr =
|
||||
// std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
// kargs.as_ptr[0],
|
||||
// kargs.bs_ptr[0],
|
||||
// s.rotating_count_,
|
||||
// size_a_buffer,
|
||||
// size_b_buffer);
|
||||
// rotating_mem_ptr->Print();
|
||||
|
||||
// preprocess = [&]() {
|
||||
// ck_tile::flush_icache();
|
||||
// rotating_mem_ptr->Next();
|
||||
// clear_gemm_output();
|
||||
// };
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// preprocess = clear_gemm_output;
|
||||
// }
|
||||
|
||||
// ave_time = ck_tile::launch_kernel_time_mask(
|
||||
// s,
|
||||
// preprocess,
|
||||
// ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0,
|
||||
// kargs));
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::experimental::builder::UniversalFactory<Algo, Inp>::make_kernel(args));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user