diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 8ea3e18d65..35bc0cf5eb 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -290,49 +290,44 @@ struct BlockGemmSpec }; template -constexpr BlockGemmSpec SetBlockGemm() +consteval BlockGemmSpec SetBlockGemm() { constexpr auto& BG = ALGORITHM.block_gemm; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; - if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE) + switch(BG.scheduler) { - scheduler = ck::BlockGemmPipelineScheduler::Intrawave; - } - else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE) - { - scheduler = ck::BlockGemmPipelineScheduler::Interwave; - } - else - { - static_assert(false, "Unknown BlockGemmPipelineScheduler"); + case BlockGemmPipelineScheduler::INTRAWAVE: + scheduler = ck::BlockGemmPipelineScheduler::Intrawave; + break; + case BlockGemmPipelineScheduler::INTERWAVE: + scheduler = ck::BlockGemmPipelineScheduler::Interwave; + break; + default: + throw "Unknown BlockGemmPipelineScheduler"; } - if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1) + switch(BG.pipeline_version) { - version = ck::BlockGemmPipelineVersion::v1; - } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2) - { - version = ck::BlockGemmPipelineVersion::v2; - } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3) - { - version = ck::BlockGemmPipelineVersion::v3; - } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4) - { - version = ck::BlockGemmPipelineVersion::v4; - } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5) - { - version = ck::BlockGemmPipelineVersion::v5; - } - else - { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + case BlockGemmPipelineVersion::V1: + version = ck::BlockGemmPipelineVersion::v1; + break; + case BlockGemmPipelineVersion::V2: + version = ck::BlockGemmPipelineVersion::v2; + break; + case BlockGemmPipelineVersion::V3: + version = ck::BlockGemmPipelineVersion::v3; + break; + case BlockGemmPipelineVersion::V4: + version = ck::BlockGemmPipelineVersion::v4; + break; + case BlockGemmPipelineVersion::V5: + version = ck::BlockGemmPipelineVersion::v5; + break; + default: + throw "Unknown BlockGemmPipelineVersion"; } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -441,18 +436,15 @@ template consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - - if constexpr(loop_scheduler == LoopScheduler::DEFAULT) + using ck_loop_sched = ck::LoopScheduler; + switch(loop_scheduler) { - return ck::LoopScheduler::Default; - } - else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE) - { - return ck::LoopScheduler::Interwave; - } - else - { - static_assert(false, "Unknown LoopScheduler"); + case LoopScheduler::DEFAULT: + return ck_loop_sched::Default; + case LoopScheduler::INTERWAVE: + return ck_loop_sched::Interwave; + default: + throw "Unknown LoopScheduler"; } } @@ -460,29 +452,21 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1) + using ck_pipeline = ck::PipelineVersion; + switch(pipeline_version) { - return ck::PipelineVersion::v1; - } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2) - { - return ck::PipelineVersion::v2; - } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3) - { - static_assert(false, "V3 is used only for stream-K."); - } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4) - { - return ck::PipelineVersion::v4; - } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY) - { - return ck::PipelineVersion::weight_only; - } - else - { - static_assert(false, "Unknown GridwiseGemmPipelineVersion"); + case GridwiseGemmPipelineVersion::V1: + return ck_pipeline::v1; + case GridwiseGemmPipelineVersion::V2: + return ck_pipeline::v2; + case GridwiseGemmPipelineVersion::V4: + return ck_pipeline::v4; + case GridwiseGemmPipelineVersion::WEIGHT_ONLY: + return ck_pipeline::weight_only; + case GridwiseGemmPipelineVersion::V3: + throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K."; + default: + throw "Unknown GridwiseGemmPipelineVersion"; } } @@ -490,74 +474,44 @@ template consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization() { constexpr auto gemm_spec = ALGORITHM.gemm_specialization; + using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization; - if constexpr(gemm_spec == GemmSpecialization::Default) + switch(gemm_spec) { - return ck::tensor_operation::device::GemmSpecialization::Default; - } - else if constexpr(gemm_spec == GemmSpecialization::MPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::KPadding) - { - return ck::tensor_operation::device::GemmSpecialization::KPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNKPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNKPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::OPadding) - { - return ck::tensor_operation::device::GemmSpecialization::OPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::KOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::KOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MKOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::NKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::NKOPadding; - } - else if constexpr(gemm_spec == GemmSpecialization::MNKOPadding) - { - return ck::tensor_operation::device::GemmSpecialization::MNKOPadding; - } - else - { - static_assert(false, "Unknown GemmSpecialization"); + case GemmSpecialization::Default: + return ck_gemm_spec::Default; + case GemmSpecialization::MPadding: + return ck_gemm_spec::MPadding; + case GemmSpecialization::NPadding: + return ck_gemm_spec::NPadding; + case GemmSpecialization::KPadding: + return ck_gemm_spec::KPadding; + case GemmSpecialization::MNPadding: + return ck_gemm_spec::MNPadding; + case GemmSpecialization::MKPadding: + return ck_gemm_spec::MKPadding; + case GemmSpecialization::NKPadding: + return ck_gemm_spec::NKPadding; + case GemmSpecialization::MNKPadding: + return ck_gemm_spec::MNKPadding; + case GemmSpecialization::OPadding: + return ck_gemm_spec::OPadding; + case GemmSpecialization::MOPadding: + return ck_gemm_spec::MOPadding; + case GemmSpecialization::NOPadding: + return ck_gemm_spec::NOPadding; + case GemmSpecialization::KOPadding: + return ck_gemm_spec::KOPadding; + case GemmSpecialization::MNOPadding: + return ck_gemm_spec::MNOPadding; + case GemmSpecialization::MKOPadding: + return ck_gemm_spec::MKOPadding; + case GemmSpecialization::NKOPadding: + return ck_gemm_spec::NKOPadding; + case GemmSpecialization::MNKOPadding: + return ck_gemm_spec::MNKOPadding; + default: + throw "Unknown GemmSpecialization"; } } @@ -565,30 +519,21 @@ template consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - - if constexpr(version == BlockGemmPipelineVersion::V1) + using ck_block_gemm = ck::BlockGemmPipelineVersion; + switch(version) { - return ck::BlockGemmPipelineVersion::v1; - } - else if constexpr(version == BlockGemmPipelineVersion::V2) - { - return ck::BlockGemmPipelineVersion::v2; - } - else if constexpr(version == BlockGemmPipelineVersion::V3) - { - return ck::BlockGemmPipelineVersion::v3; - } - else if constexpr(version == BlockGemmPipelineVersion::V4) - { - return ck::BlockGemmPipelineVersion::v4; - } - else if constexpr(version == BlockGemmPipelineVersion::V5) - { - return ck::BlockGemmPipelineVersion::v5; - } - else - { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + case BlockGemmPipelineVersion::V1: + return ck_block_gemm::v1; + case BlockGemmPipelineVersion::V2: + return ck_block_gemm::v2; + case BlockGemmPipelineVersion::V3: + return ck_block_gemm::v3; + case BlockGemmPipelineVersion::V4: + return ck_block_gemm::v4; + case BlockGemmPipelineVersion::V5: + return ck_block_gemm::v5; + default: + throw "Unknown BlockGemmPipelineVersion"; } } @@ -596,26 +541,19 @@ template consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization() { constexpr auto specialization = ALGORITHM.fwd_specialization; - - if constexpr(specialization == ConvFwdSpecialization::DEFAULT) + using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; + switch(specialization) { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - } - else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3) - { - return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3; - } - else - { - static_assert(false, "Unknown ConvFwdSpecialization"); + case ConvFwdSpecialization::DEFAULT: + return ck_conv_spec::Default; + case ConvFwdSpecialization::FILTER_1X1_PAD0: + return ck_conv_spec::Filter1x1Pad0; + case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: + return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvFwdSpecialization::FILTER_3x3: + return ck_conv_spec::Filter3x3; + default: + throw "Unknown ConvFwdSpecialization"; } }