mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 20:27:42 +00:00
Change if-else statements into switch in conv factory.
This commit is contained in:
@@ -290,49 +290,44 @@ struct BlockGemmSpec
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockGemmSpec SetBlockGemm()
|
||||
consteval BlockGemmSpec SetBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm;
|
||||
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
|
||||
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
|
||||
switch(BG.scheduler)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
}
|
||||
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
|
||||
{
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineScheduler");
|
||||
case BlockGemmPipelineScheduler::INTRAWAVE:
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
break;
|
||||
case BlockGemmPipelineScheduler::INTERWAVE:
|
||||
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
break;
|
||||
default:
|
||||
throw "Unknown BlockGemmPipelineScheduler";
|
||||
}
|
||||
|
||||
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
|
||||
switch(BG.pipeline_version)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
|
||||
{
|
||||
version = ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
case BlockGemmPipelineVersion::V1:
|
||||
version = ck::BlockGemmPipelineVersion::v1;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V2:
|
||||
version = ck::BlockGemmPipelineVersion::v2;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V3:
|
||||
version = ck::BlockGemmPipelineVersion::v3;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V4:
|
||||
version = ck::BlockGemmPipelineVersion::v4;
|
||||
break;
|
||||
case BlockGemmPipelineVersion::V5:
|
||||
version = ck::BlockGemmPipelineVersion::v5;
|
||||
break;
|
||||
default:
|
||||
throw "Unknown BlockGemmPipelineVersion";
|
||||
}
|
||||
|
||||
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
|
||||
@@ -441,18 +436,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::LoopScheduler SetLoopScheduler()
|
||||
{
|
||||
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
|
||||
|
||||
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
|
||||
using ck_loop_sched = ck::LoopScheduler;
|
||||
switch(loop_scheduler)
|
||||
{
|
||||
return ck::LoopScheduler::Default;
|
||||
}
|
||||
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
|
||||
{
|
||||
return ck::LoopScheduler::Interwave;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown LoopScheduler");
|
||||
case LoopScheduler::DEFAULT:
|
||||
return ck_loop_sched::Default;
|
||||
case LoopScheduler::INTERWAVE:
|
||||
return ck_loop_sched::Interwave;
|
||||
default:
|
||||
throw "Unknown LoopScheduler";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -460,29 +452,21 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
|
||||
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
|
||||
using ck_pipeline = ck::PipelineVersion;
|
||||
switch(pipeline_version)
|
||||
{
|
||||
return ck::PipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
|
||||
{
|
||||
return ck::PipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
|
||||
{
|
||||
static_assert(false, "V3 is used only for stream-K.");
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
|
||||
{
|
||||
return ck::PipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
|
||||
{
|
||||
return ck::PipelineVersion::weight_only;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
|
||||
case GridwiseGemmPipelineVersion::V1:
|
||||
return ck_pipeline::v1;
|
||||
case GridwiseGemmPipelineVersion::V2:
|
||||
return ck_pipeline::v2;
|
||||
case GridwiseGemmPipelineVersion::V4:
|
||||
return ck_pipeline::v4;
|
||||
case GridwiseGemmPipelineVersion::WEIGHT_ONLY:
|
||||
return ck_pipeline::weight_only;
|
||||
case GridwiseGemmPipelineVersion::V3:
|
||||
throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K.";
|
||||
default:
|
||||
throw "Unknown GridwiseGemmPipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -490,74 +474,44 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
|
||||
{
|
||||
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
|
||||
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
if constexpr(gemm_spec == GemmSpecialization::Default)
|
||||
switch(gemm_spec)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::KPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::KPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNKPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::OPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::OPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::KOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::KOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MKOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::NKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::NKOPadding;
|
||||
}
|
||||
else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
return ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown GemmSpecialization");
|
||||
case GemmSpecialization::Default:
|
||||
return ck_gemm_spec::Default;
|
||||
case GemmSpecialization::MPadding:
|
||||
return ck_gemm_spec::MPadding;
|
||||
case GemmSpecialization::NPadding:
|
||||
return ck_gemm_spec::NPadding;
|
||||
case GemmSpecialization::KPadding:
|
||||
return ck_gemm_spec::KPadding;
|
||||
case GemmSpecialization::MNPadding:
|
||||
return ck_gemm_spec::MNPadding;
|
||||
case GemmSpecialization::MKPadding:
|
||||
return ck_gemm_spec::MKPadding;
|
||||
case GemmSpecialization::NKPadding:
|
||||
return ck_gemm_spec::NKPadding;
|
||||
case GemmSpecialization::MNKPadding:
|
||||
return ck_gemm_spec::MNKPadding;
|
||||
case GemmSpecialization::OPadding:
|
||||
return ck_gemm_spec::OPadding;
|
||||
case GemmSpecialization::MOPadding:
|
||||
return ck_gemm_spec::MOPadding;
|
||||
case GemmSpecialization::NOPadding:
|
||||
return ck_gemm_spec::NOPadding;
|
||||
case GemmSpecialization::KOPadding:
|
||||
return ck_gemm_spec::KOPadding;
|
||||
case GemmSpecialization::MNOPadding:
|
||||
return ck_gemm_spec::MNOPadding;
|
||||
case GemmSpecialization::MKOPadding:
|
||||
return ck_gemm_spec::MKOPadding;
|
||||
case GemmSpecialization::NKOPadding:
|
||||
return ck_gemm_spec::NKOPadding;
|
||||
case GemmSpecialization::MNKOPadding:
|
||||
return ck_gemm_spec::MNKOPadding;
|
||||
default:
|
||||
throw "Unknown GemmSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -565,30 +519,21 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == BlockGemmPipelineVersion::V1)
|
||||
using ck_block_gemm = ck::BlockGemmPipelineVersion;
|
||||
switch(version)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V2)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
case BlockGemmPipelineVersion::V1:
|
||||
return ck_block_gemm::v1;
|
||||
case BlockGemmPipelineVersion::V2:
|
||||
return ck_block_gemm::v2;
|
||||
case BlockGemmPipelineVersion::V3:
|
||||
return ck_block_gemm::v3;
|
||||
case BlockGemmPipelineVersion::V4:
|
||||
return ck_block_gemm::v4;
|
||||
case BlockGemmPipelineVersion::V5:
|
||||
return ck_block_gemm::v5;
|
||||
default:
|
||||
throw "Unknown BlockGemmPipelineVersion";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,26 +541,19 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
|
||||
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
|
||||
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(specialization)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown ConvFwdSpecialization");
|
||||
case ConvFwdSpecialization::DEFAULT:
|
||||
return ck_conv_spec::Default;
|
||||
case ConvFwdSpecialization::FILTER_1X1_PAD0:
|
||||
return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
|
||||
return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvFwdSpecialization::FILTER_3x3:
|
||||
return ck_conv_spec::Filter3x3;
|
||||
default:
|
||||
throw "Unknown ConvFwdSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user