Change if-else statements into switch in conv factory.

This commit is contained in:
Ville Pietilä
2025-11-04 10:57:50 +00:00
parent adf0a80290
commit 69a93a57f0

View File

@@ -290,49 +290,44 @@ struct BlockGemmSpec
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockGemmSpec SetBlockGemm()
consteval BlockGemmSpec SetBlockGemm()
{
constexpr auto& BG = ALGORITHM.block_gemm;
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;
if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE)
switch(BG.scheduler)
{
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
}
else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE)
{
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineScheduler");
case BlockGemmPipelineScheduler::INTRAWAVE:
scheduler = ck::BlockGemmPipelineScheduler::Intrawave;
break;
case BlockGemmPipelineScheduler::INTERWAVE:
scheduler = ck::BlockGemmPipelineScheduler::Interwave;
break;
default:
throw "Unknown BlockGemmPipelineScheduler";
}
if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1)
switch(BG.pipeline_version)
{
version = ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2)
{
version = ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3)
{
version = ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4)
{
version = ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5)
{
version = ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
case BlockGemmPipelineVersion::V1:
version = ck::BlockGemmPipelineVersion::v1;
break;
case BlockGemmPipelineVersion::V2:
version = ck::BlockGemmPipelineVersion::v2;
break;
case BlockGemmPipelineVersion::V3:
version = ck::BlockGemmPipelineVersion::v3;
break;
case BlockGemmPipelineVersion::V4:
version = ck::BlockGemmPipelineVersion::v4;
break;
case BlockGemmPipelineVersion::V5:
version = ck::BlockGemmPipelineVersion::v5;
break;
default:
throw "Unknown BlockGemmPipelineVersion";
}
return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler};
@@ -441,18 +436,15 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::LoopScheduler SetLoopScheduler()
{
constexpr auto loop_scheduler = ALGORITHM.loop_scheduler;
if constexpr(loop_scheduler == LoopScheduler::DEFAULT)
using ck_loop_sched = ck::LoopScheduler;
switch(loop_scheduler)
{
return ck::LoopScheduler::Default;
}
else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE)
{
return ck::LoopScheduler::Interwave;
}
else
{
static_assert(false, "Unknown LoopScheduler");
case LoopScheduler::DEFAULT:
return ck_loop_sched::Default;
case LoopScheduler::INTERWAVE:
return ck_loop_sched::Interwave;
default:
throw "Unknown LoopScheduler";
}
}
@@ -460,29 +452,21 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
{
constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version;
if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1)
using ck_pipeline = ck::PipelineVersion;
switch(pipeline_version)
{
return ck::PipelineVersion::v1;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2)
{
return ck::PipelineVersion::v2;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3)
{
static_assert(false, "V3 is used only for stream-K.");
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4)
{
return ck::PipelineVersion::v4;
}
else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY)
{
return ck::PipelineVersion::weight_only;
}
else
{
static_assert(false, "Unknown GridwiseGemmPipelineVersion");
case GridwiseGemmPipelineVersion::V1:
return ck_pipeline::v1;
case GridwiseGemmPipelineVersion::V2:
return ck_pipeline::v2;
case GridwiseGemmPipelineVersion::V4:
return ck_pipeline::v4;
case GridwiseGemmPipelineVersion::WEIGHT_ONLY:
return ck_pipeline::weight_only;
case GridwiseGemmPipelineVersion::V3:
throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K.";
default:
throw "Unknown GridwiseGemmPipelineVersion";
}
}
@@ -490,74 +474,44 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization()
{
constexpr auto gemm_spec = ALGORITHM.gemm_specialization;
using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization;
if constexpr(gemm_spec == GemmSpecialization::Default)
switch(gemm_spec)
{
return ck::tensor_operation::device::GemmSpecialization::Default;
}
else if constexpr(gemm_spec == GemmSpecialization::MPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::NPadding)
{
return ck::tensor_operation::device::GemmSpecialization::NPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::KPadding)
{
return ck::tensor_operation::device::GemmSpecialization::KPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::MNPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MNPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::MKPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MKPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::NKPadding)
{
return ck::tensor_operation::device::GemmSpecialization::NKPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::MNKPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::OPadding)
{
return ck::tensor_operation::device::GemmSpecialization::OPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::MOPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MOPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::NOPadding)
{
return ck::tensor_operation::device::GemmSpecialization::NOPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::KOPadding)
{
return ck::tensor_operation::device::GemmSpecialization::KOPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::MNOPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MNOPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::MKOPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MKOPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::NKOPadding)
{
return ck::tensor_operation::device::GemmSpecialization::NKOPadding;
}
else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding)
{
return ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
}
else
{
static_assert(false, "Unknown GemmSpecialization");
case GemmSpecialization::Default:
return ck_gemm_spec::Default;
case GemmSpecialization::MPadding:
return ck_gemm_spec::MPadding;
case GemmSpecialization::NPadding:
return ck_gemm_spec::NPadding;
case GemmSpecialization::KPadding:
return ck_gemm_spec::KPadding;
case GemmSpecialization::MNPadding:
return ck_gemm_spec::MNPadding;
case GemmSpecialization::MKPadding:
return ck_gemm_spec::MKPadding;
case GemmSpecialization::NKPadding:
return ck_gemm_spec::NKPadding;
case GemmSpecialization::MNKPadding:
return ck_gemm_spec::MNKPadding;
case GemmSpecialization::OPadding:
return ck_gemm_spec::OPadding;
case GemmSpecialization::MOPadding:
return ck_gemm_spec::MOPadding;
case GemmSpecialization::NOPadding:
return ck_gemm_spec::NOPadding;
case GemmSpecialization::KOPadding:
return ck_gemm_spec::KOPadding;
case GemmSpecialization::MNOPadding:
return ck_gemm_spec::MNOPadding;
case GemmSpecialization::MKOPadding:
return ck_gemm_spec::MKOPadding;
case GemmSpecialization::NKOPadding:
return ck_gemm_spec::NKOPadding;
case GemmSpecialization::MNKOPadding:
return ck_gemm_spec::MNKOPadding;
default:
throw "Unknown GemmSpecialization";
}
}
@@ -565,30 +519,21 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
if constexpr(version == BlockGemmPipelineVersion::V1)
using ck_block_gemm = ck::BlockGemmPipelineVersion;
switch(version)
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V2)
{
return ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(version == BlockGemmPipelineVersion::V4)
{
return ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(version == BlockGemmPipelineVersion::V5)
{
return ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
case BlockGemmPipelineVersion::V1:
return ck_block_gemm::v1;
case BlockGemmPipelineVersion::V2:
return ck_block_gemm::v2;
case BlockGemmPipelineVersion::V3:
return ck_block_gemm::v3;
case BlockGemmPipelineVersion::V4:
return ck_block_gemm::v4;
case BlockGemmPipelineVersion::V5:
return ck_block_gemm::v5;
default:
throw "Unknown BlockGemmPipelineVersion";
}
}
@@ -596,26 +541,19 @@ template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
{
constexpr auto specialization = ALGORITHM.fwd_specialization;
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(specialization)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
}
else
{
static_assert(false, "Unknown ConvFwdSpecialization");
case ConvFwdSpecialization::DEFAULT:
return ck_conv_spec::Default;
case ConvFwdSpecialization::FILTER_1X1_PAD0:
return ck_conv_spec::Filter1x1Pad0;
case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0:
return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvFwdSpecialization::FILTER_3x3:
return ck_conv_spec::Filter3x3;
default:
throw "Unknown ConvFwdSpecialization";
}
}