mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK-Tile] Refactor base pipeline usage (#3251)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder
This commit is contained in:
@@ -46,14 +46,6 @@ struct SplitKTwoStageInvoker
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -67,41 +59,22 @@ struct SplitKTwoStageInvoker
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -230,7 +203,7 @@ struct SplitKTwoStageInvoker
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
@@ -244,21 +217,15 @@ struct SplitKTwoStageInvoker
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -133,14 +133,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
@@ -154,19 +146,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
// Create base GEMM arguments pointing to workspace instead of final output
|
||||
// The workspace will store partial results from each K-split
|
||||
ck_tile::GemmHostArgs base_args(args.a_ptr,
|
||||
@@ -179,23 +158,18 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
args.stride_A,
|
||||
args.stride_B,
|
||||
args.stride_E);
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&]() {
|
||||
// use SET operation since each K-split writes to separate memory
|
||||
constexpr auto memory_operation = ck_tile::memory_operation_enum::set;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
@@ -276,29 +250,20 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
// For workspace mode, always use SET operation since each K-split writes to separate memory
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return Run();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -33,14 +33,6 @@ struct WeightPreshuffleInvoker
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -54,39 +46,20 @@ struct WeightPreshuffleInvoker
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -139,6 +112,7 @@ struct WeightPreshuffleInvoker
|
||||
<< "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
@@ -183,21 +157,14 @@ struct WeightPreshuffleInvoker
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -63,14 +63,17 @@ void permute_tensor_b(Tensor& tensor)
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler,
|
||||
true,
|
||||
ck_tile::TailNumber::Full>;
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ADataType,
|
||||
true>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
@@ -34,14 +34,6 @@ struct UniversalInvoker
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -55,39 +47,22 @@ struct UniversalInvoker
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -173,25 +148,19 @@ struct UniversalInvoker
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -59,7 +59,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
@@ -68,105 +67,74 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
float ave_time{0};
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
@@ -42,12 +42,6 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
@@ -56,109 +50,83 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
float ave_time{0};
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName()
|
||||
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
|
||||
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
|
||||
@@ -42,12 +42,6 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
@@ -56,39 +50,20 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -131,39 +106,27 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
|
||||
@@ -41,13 +41,6 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
@@ -60,111 +53,82 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
GemmConfig::Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop =
|
||||
// if preshuffle == true then num_loop is recalculated for each group in the kernel code
|
||||
TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName()
|
||||
<< " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
|
||||
<< grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
|
||||
<< blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
else
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
};
|
||||
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
|
||||
@@ -61,8 +61,6 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
@@ -71,101 +69,72 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
float ave_time{0};
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
|
||||
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
|
||||
<< ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_d_fp16_example.inc"
|
||||
|
||||
@@ -57,43 +57,9 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdData<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
@@ -103,8 +69,6 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
InDataType,
|
||||
@@ -170,26 +134,19 @@ struct GroupedConvolutionBackwardDataInvoker
|
||||
kargs.in_ptr, 0, args.template GetInputByte<InDataType>(), s.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -57,43 +57,9 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
args.output_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
@@ -103,8 +69,6 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
@@ -176,26 +140,19 @@ struct GroupedConvolutionBackwardWeightInvoker
|
||||
}
|
||||
};
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -60,42 +60,9 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
OutDataType,
|
||||
InDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
|
||||
args.output_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
@@ -105,8 +72,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
WeiDataType,
|
||||
@@ -209,7 +174,6 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
@@ -228,7 +192,7 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
s.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs),
|
||||
@@ -242,22 +206,15 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker
|
||||
ck_tile::make_tuple(shape[1], 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<WeiDataType*>(c_ptr)));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -65,148 +65,96 @@ struct GroupedConvolutionForwardInvoker
|
||||
GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
|
||||
GroupedConvTraitsType::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsFwd<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
// Split-K parameters
|
||||
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
// =====================================================================
|
||||
// Regular Convolution: Simple, no split-image
|
||||
// =====================================================================
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeA,
|
||||
GroupedConvTraitsType::VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
typename GroupedConvTraitsType::FixedGemmParams::ELayout,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
ConvConfig::M_Warp,
|
||||
ConvConfig::N_Warp,
|
||||
ConvConfig::M_Warp_Tile,
|
||||
ConvConfig::N_Warp_Tile,
|
||||
ConvConfig::K_Warp_Tile,
|
||||
GroupedConvTraitsType::FixedGemmParams::TransposeC,
|
||||
memory_operation,
|
||||
ConvConfig::NumWaveGroups,
|
||||
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Split-K lambda
|
||||
// =====================================================================
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<ConvConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Regular Convolution Example: ALWAYS uses regular path (Kernel<false>)
|
||||
// Split-K dispatch
|
||||
// =====================================================================
|
||||
// This example demonstrates regular convolution without split-image.
|
||||
// For large images that don't fit in memory, use
|
||||
// grouped_convolution_forward_split_image.cpp
|
||||
|
||||
// Launch kernel using regular path (no split-image)
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(MemoryOpAtomicAdd{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -72,36 +72,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::Persistent,
|
||||
ConvConfig::NumWaveGroups>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsTypeDefault::template GroupedConvImplicitGemmTraitsFwd<
|
||||
ConvConfig::NumWaveGroups>,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
GroupedConvTraitsTypeDefault::FixedGemmParams::FixedVectorSize,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeA,
|
||||
GroupedConvTraitsTypeDefault::VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
ConvConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
// Split-K parameters
|
||||
const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
using TransformType =
|
||||
ck_tile::TransformConvFwdToGemm<NDimSpatial,
|
||||
ck_tile::ConvolutionSpecialization::Default,
|
||||
@@ -238,16 +208,12 @@ struct GroupedConvolutionForwardInvoker
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
|
||||
// =====================================================================
|
||||
// Kernel launch lambda: Uses EnableSplitImage based on layout support
|
||||
// =====================================================================
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_,
|
||||
const auto enable_split_image_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ConvConfig::Scheduler;
|
||||
const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool EnableSplitImage = enable_split_image_.value;
|
||||
|
||||
@@ -262,8 +228,6 @@ struct GroupedConvolutionForwardInvoker
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
@@ -368,33 +332,17 @@ struct GroupedConvolutionForwardInvoker
|
||||
// =====================================================================
|
||||
if(use_split_image)
|
||||
{
|
||||
const auto RunSplitImage = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
return Run(
|
||||
has_hot_loop_, tail_number_, MemoryOpSet{}, ck_tile::bool_constant<true>{});
|
||||
else
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
MemoryOpAtomicAdd{},
|
||||
ck_tile::bool_constant<true>{});
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num);
|
||||
if(args.k_batch == 1)
|
||||
return Run(MemoryOpSet{}, ck_tile::bool_constant<true>{});
|
||||
else
|
||||
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto RunRegular = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
MemoryOpSet{},
|
||||
ck_tile::bool_constant<false>{});
|
||||
else
|
||||
return Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
MemoryOpAtomicAdd{},
|
||||
ck_tile::bool_constant<false>{});
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(RunRegular, has_hot_loop, tail_num);
|
||||
if(args.k_batch == 1)
|
||||
return Run(MemoryOpSet{}, ck_tile::bool_constant<false>{});
|
||||
else
|
||||
return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -63,8 +63,6 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
@@ -73,103 +71,74 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf
|
||||
BsLayout,
|
||||
ELayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
AElementWise,
|
||||
BElementWise>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
float ave_time{0};
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
AElementWise,
|
||||
BElementWise>;
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
|
||||
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
else
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y
|
||||
<< ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y
|
||||
<< ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_abd_fp16_example.inc"
|
||||
|
||||
@@ -90,24 +90,9 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
|
||||
ck_tile::index_t K_total = 1;
|
||||
for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM + NumDimK; ++i)
|
||||
{
|
||||
K_total *= args.A_dims[i];
|
||||
}
|
||||
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
const auto Run = [&]() {
|
||||
constexpr auto memory_operation =
|
||||
ck_tile::memory_operation_enum::set; // Always set (no atomic_add)
|
||||
|
||||
@@ -116,9 +101,7 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
|
||||
@@ -166,14 +149,10 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs<DsDataT
|
||||
|
||||
auto kernel = ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs);
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s, kernel);
|
||||
|
||||
return ave_time;
|
||||
return ck_tile::launch_kernel(s, kernel);
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
return Run();
|
||||
}
|
||||
|
||||
#define HANDLE_CASE(G, M, N, K) \
|
||||
|
||||
Reference in New Issue
Block a user