diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 35bc0cf5eb..336e668099 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -299,35 +299,23 @@ consteval BlockGemmSpec SetBlockGemm() switch(BG.scheduler) { - case BlockGemmPipelineScheduler::INTRAWAVE: - scheduler = ck::BlockGemmPipelineScheduler::Intrawave; - break; - case BlockGemmPipelineScheduler::INTERWAVE: - scheduler = ck::BlockGemmPipelineScheduler::Interwave; - break; - default: - throw "Unknown BlockGemmPipelineScheduler"; + case BlockGemmPipelineScheduler::INTRAWAVE: + scheduler = ck::BlockGemmPipelineScheduler::Intrawave; + break; + case BlockGemmPipelineScheduler::INTERWAVE: + scheduler = ck::BlockGemmPipelineScheduler::Interwave; + break; + default: throw "Unknown BlockGemmPipelineScheduler"; } switch(BG.pipeline_version) { - case BlockGemmPipelineVersion::V1: - version = ck::BlockGemmPipelineVersion::v1; - break; - case BlockGemmPipelineVersion::V2: - version = ck::BlockGemmPipelineVersion::v2; - break; - case BlockGemmPipelineVersion::V3: - version = ck::BlockGemmPipelineVersion::v3; - break; - case BlockGemmPipelineVersion::V4: - version = ck::BlockGemmPipelineVersion::v4; - break; - case BlockGemmPipelineVersion::V5: - version = ck::BlockGemmPipelineVersion::v5; - break; - default: - throw "Unknown BlockGemmPipelineVersion"; + case BlockGemmPipelineVersion::V1: version = ck::BlockGemmPipelineVersion::v1; break; + case BlockGemmPipelineVersion::V2: version = ck::BlockGemmPipelineVersion::v2; break; + case BlockGemmPipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; + case BlockGemmPipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; + case BlockGemmPipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + default: throw "Unknown BlockGemmPipelineVersion"; } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -436,15 +424,12 @@ template consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - using ck_loop_sched = ck::LoopScheduler; + using ck_loop_sched = ck::LoopScheduler; switch(loop_scheduler) { - case LoopScheduler::DEFAULT: - return ck_loop_sched::Default; - case LoopScheduler::INTERWAVE: - return ck_loop_sched::Interwave; - default: - throw "Unknown LoopScheduler"; + case LoopScheduler::DEFAULT: return ck_loop_sched::Default; + case LoopScheduler::INTERWAVE: return ck_loop_sched::Interwave; + default: throw "Unknown LoopScheduler"; } } @@ -452,21 +437,16 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - using ck_pipeline = ck::PipelineVersion; + using ck_pipeline = ck::PipelineVersion; switch(pipeline_version) { - case GridwiseGemmPipelineVersion::V1: - return ck_pipeline::v1; - case GridwiseGemmPipelineVersion::V2: - return ck_pipeline::v2; - case GridwiseGemmPipelineVersion::V4: - return ck_pipeline::v4; - case GridwiseGemmPipelineVersion::WEIGHT_ONLY: - return ck_pipeline::weight_only; - case GridwiseGemmPipelineVersion::V3: - throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K."; - default: - throw "Unknown GridwiseGemmPipelineVersion"; + case GridwiseGemmPipelineVersion::V1: return ck_pipeline::v1; + case GridwiseGemmPipelineVersion::V2: return ck_pipeline::v2; + case GridwiseGemmPipelineVersion::V4: return ck_pipeline::v4; + case GridwiseGemmPipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; + case GridwiseGemmPipelineVersion::V3: + throw "GridwiseGemmPipelineVersion::V3 is used only for stream-K."; + default: throw "Unknown GridwiseGemmPipelineVersion"; } } @@ -474,44 +454,27 @@ template consteval ck::tensor_operation::device::GemmSpecialization SetGemmSpecialization() { constexpr auto gemm_spec = ALGORITHM.gemm_specialization; - using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization; + using ck_gemm_spec = ck::tensor_operation::device::GemmSpecialization; switch(gemm_spec) { - case GemmSpecialization::Default: - return ck_gemm_spec::Default; - case GemmSpecialization::MPadding: - return ck_gemm_spec::MPadding; - case GemmSpecialization::NPadding: - return ck_gemm_spec::NPadding; - case GemmSpecialization::KPadding: - return ck_gemm_spec::KPadding; - case GemmSpecialization::MNPadding: - return ck_gemm_spec::MNPadding; - case GemmSpecialization::MKPadding: - return ck_gemm_spec::MKPadding; - case GemmSpecialization::NKPadding: - return ck_gemm_spec::NKPadding; - case GemmSpecialization::MNKPadding: - return ck_gemm_spec::MNKPadding; - case GemmSpecialization::OPadding: - return ck_gemm_spec::OPadding; - case GemmSpecialization::MOPadding: - return ck_gemm_spec::MOPadding; - case GemmSpecialization::NOPadding: - return ck_gemm_spec::NOPadding; - case GemmSpecialization::KOPadding: - return ck_gemm_spec::KOPadding; - case GemmSpecialization::MNOPadding: - return ck_gemm_spec::MNOPadding; - case GemmSpecialization::MKOPadding: - return ck_gemm_spec::MKOPadding; - case GemmSpecialization::NKOPadding: - return ck_gemm_spec::NKOPadding; - case GemmSpecialization::MNKOPadding: - return ck_gemm_spec::MNKOPadding; - default: - throw "Unknown GemmSpecialization"; + case GemmSpecialization::Default: return ck_gemm_spec::Default; + case GemmSpecialization::MPadding: return ck_gemm_spec::MPadding; + case GemmSpecialization::NPadding: return ck_gemm_spec::NPadding; + case GemmSpecialization::KPadding: return ck_gemm_spec::KPadding; + case GemmSpecialization::MNPadding: return ck_gemm_spec::MNPadding; + case GemmSpecialization::MKPadding: return ck_gemm_spec::MKPadding; + case GemmSpecialization::NKPadding: return ck_gemm_spec::NKPadding; + case GemmSpecialization::MNKPadding: return ck_gemm_spec::MNKPadding; + case GemmSpecialization::OPadding: return ck_gemm_spec::OPadding; + case GemmSpecialization::MOPadding: return ck_gemm_spec::MOPadding; + case GemmSpecialization::NOPadding: return ck_gemm_spec::NOPadding; + case GemmSpecialization::KOPadding: return ck_gemm_spec::KOPadding; + case GemmSpecialization::MNOPadding: return ck_gemm_spec::MNOPadding; + case GemmSpecialization::MKOPadding: return ck_gemm_spec::MKOPadding; + case GemmSpecialization::NKOPadding: return ck_gemm_spec::NKOPadding; + case GemmSpecialization::MNKOPadding: return ck_gemm_spec::MNKOPadding; + default: throw "Unknown GemmSpecialization"; } } @@ -519,21 +482,15 @@ template consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - using ck_block_gemm = ck::BlockGemmPipelineVersion; + using ck_block_gemm = ck::BlockGemmPipelineVersion; switch(version) { - case BlockGemmPipelineVersion::V1: - return ck_block_gemm::v1; - case BlockGemmPipelineVersion::V2: - return ck_block_gemm::v2; - case BlockGemmPipelineVersion::V3: - return ck_block_gemm::v3; - case BlockGemmPipelineVersion::V4: - return ck_block_gemm::v4; - case BlockGemmPipelineVersion::V5: - return ck_block_gemm::v5; - default: - throw "Unknown BlockGemmPipelineVersion"; + case BlockGemmPipelineVersion::V1: return ck_block_gemm::v1; + case BlockGemmPipelineVersion::V2: return ck_block_gemm::v2; + case BlockGemmPipelineVersion::V3: return ck_block_gemm::v3; + case BlockGemmPipelineVersion::V4: return ck_block_gemm::v4; + case BlockGemmPipelineVersion::V5: return ck_block_gemm::v5; + default: throw "Unknown BlockGemmPipelineVersion"; } } @@ -541,19 +498,14 @@ template consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization() { constexpr auto specialization = ALGORITHM.fwd_specialization; - using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; + using ck_conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization; switch(specialization) { - case ConvFwdSpecialization::DEFAULT: - return ck_conv_spec::Default; - case ConvFwdSpecialization::FILTER_1X1_PAD0: - return ck_conv_spec::Filter1x1Pad0; - case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: - return ck_conv_spec::Filter1x1Stride1Pad0; - case ConvFwdSpecialization::FILTER_3x3: - return ck_conv_spec::Filter3x3; - default: - throw "Unknown ConvFwdSpecialization"; + case ConvFwdSpecialization::DEFAULT: return ck_conv_spec::Default; + case ConvFwdSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; + case ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; + case ConvFwdSpecialization::FILTER_3x3: return ck_conv_spec::Filter3x3; + default: throw "Unknown ConvFwdSpecialization"; } } diff --git a/experimental/builder/test/utils/ckb_conv_test_configs.hpp b/experimental/builder/test/utils/ckb_conv_test_configs.hpp index 0db8af2dc0..02c3dfec9b 100644 --- a/experimental/builder/test/utils/ckb_conv_test_configs.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_configs.hpp @@ -12,124 +12,125 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; -constexpr BlockTransferABC FwdBlockTransfer_4x64_1 {.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 8}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 2, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = false}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr BlockTransferABC FwdBlockTransfer_4x64_1{ + .block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 8}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 2, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = false}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; -constexpr BlockTransferABC FwdBlockTransfer_4x16x1 {.block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 16, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 8, - .lds_dst_scalar_per_vector = 8, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr BlockTransferABC FwdBlockTransfer_4x16x1{ + .block_transfer_a = {.k0 = 4, .m_n = 16, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 16, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 16, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 8, + .lds_dst_scalar_per_vector = 8, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; -constexpr BlockTransferABC FwdBlockTransfer_4x32x1 {.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, - .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, - .thread_cluster_dims_c = {.m_block = 1, - .m_wave_per_xdl = 32, - .n_block = 1, - .n_wave_per_xdl = 4}, - .lds_transfer_a = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .lds_transfer_b = {.src_vector_dim = 2, - .src_scalar_per_vector = 16, - .lds_dst_scalar_per_vector = 16, - .is_direct_load = false, - .lds_padding = true}, - .epilogue_c = {.m_per_wave_per_shuffle = 1, - .n_per_wave_per_shuffle = 1, - .scalar_per_vector = 8}, - .block_transfer_access_order_a = {1, 0, 2}, - .block_transfer_access_order_b = {1, 0, 2}, - .src_access_order_a = {1, 0, 2}, - .src_access_order_b = {1, 0, 2}}; +constexpr BlockTransferABC FwdBlockTransfer_4x32x1{ + .block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, + .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, + .thread_cluster_dims_c = {.m_block = 1, + .m_wave_per_xdl = 32, + .n_block = 1, + .n_wave_per_xdl = 4}, + .lds_transfer_a = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .lds_transfer_b = {.src_vector_dim = 2, + .src_scalar_per_vector = 16, + .lds_dst_scalar_per_vector = 16, + .is_direct_load = false, + .lds_padding = true}, + .epilogue_c = {.m_per_wave_per_shuffle = 1, + .n_per_wave_per_shuffle = 1, + .scalar_per_vector = 8}, + .block_transfer_access_order_a = {1, 0, 2}, + .block_transfer_access_order_b = {1, 0, 2}, + .src_access_order_a = {1, 0, 2}, + .src_access_order_b = {1, 0, 2}}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave {.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 4, - .n_xdl_per_wave = 4}; +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_4x4_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}; -constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave {.ak1 = 8, - .bk1 = 8, - .m_per_xdl = 32, - .n_per_xdl = 32, - .m_xdl_per_wave = 2, - .n_xdl_per_wave = 1}; +constexpr GridwiseXdlGemm FwdGemmParams_Xdl_2x1_per_wave{ + .ak1 = 8, .bk1 = 8, .m_per_xdl = 32, .n_per_xdl = 32, .m_xdl_per_wave = 2, .n_xdl_per_wave = 1}; -constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave {.k1 = 8, - .m_per_wmma = 32, - .n_per_wmma = 32, - .m_wmma_per_wave = 2, - .n_wmma_per_wave = 1, - .pipeline_version = GridwiseGemmPipelineVersion::V1}; +constexpr GridwiseWmmaGemm FwdGemmParams_Wmma_2x1_per_wave{.k1 = 8, + .m_per_wmma = 32, + .n_per_wmma = 32, + .m_wmma_per_wave = 2, + .n_wmma_per_wave = 1, + .pipeline_version = + GridwiseGemmPipelineVersion::V1}; -constexpr ThreadBlock FwdThreadBlock_256x256x32 {.block_size = 256, - .tile_size = {.m = 256, .n = 256, .k = 32}}; +constexpr ThreadBlock FwdThreadBlock_256x256x32{.block_size = 256, + .tile_size = {.m = 256, .n = 256, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_128x128x32 {.block_size = 256, - .tile_size = {.m = 128, .n = 128, .k = 32}}; +constexpr ThreadBlock FwdThreadBlock_128x128x32{.block_size = 256, + .tile_size = {.m = 128, .n = 128, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_64x32x32 {.block_size = 64, - .tile_size = {.m = 64, .n = 32, .k = 32}}; +constexpr ThreadBlock FwdThreadBlock_64x32x32{.block_size = 64, + .tile_size = {.m = 64, .n = 32, .k = 32}}; -constexpr ThreadBlock FwdThreadBlock_64x64x64 {.block_size = 128, - .tile_size = {.m = 64, .n = 64, .k = 64}}; +constexpr ThreadBlock FwdThreadBlock_64x64x64{.block_size = 128, + .tile_size = {.m = 64, .n = 64, .k = 64}}; constexpr BlockGemm BlockGemmDesc_v1_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V1, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = + BlockGemmPipelineScheduler::INTRAWAVE}; constexpr BlockGemm BlockGemmDesc_v2_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V2, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = + BlockGemmPipelineScheduler::INTRAWAVE}; constexpr BlockGemm BlockGemmDesc_v3_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V3, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = + BlockGemmPipelineScheduler::INTRAWAVE}; constexpr BlockGemm BlockGemmDesc_v4_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V4, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = + BlockGemmPipelineScheduler::INTRAWAVE}; constexpr BlockGemm BlockGemmDesc_v5_intrawave = {.pipeline_version = BlockGemmPipelineVersion::V5, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = + BlockGemmPipelineScheduler::INTRAWAVE}; } // namespace ck_tile::builder::test_utils diff --git a/experimental/builder/test/utils/ckb_conv_test_utils.hpp b/experimental/builder/test/utils/ckb_conv_test_utils.hpp index c04117e8d6..c2c6ae63cf 100644 --- a/experimental/builder/test/utils/ckb_conv_test_utils.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_utils.hpp @@ -14,7 +14,6 @@ namespace ck_tile::builder::test_utils { using namespace ck_tile::builder; using namespace test; - // Common test implementation template constexpr void run_test(const std::vector& kernel_instance_components) @@ -28,7 +27,7 @@ constexpr void run_test(const std::vector& kernel_instance_componen const auto invoker_ptr = instance.MakeInvokerPointer(); EXPECT_NE(invoker_ptr, nullptr); - for (const auto& component : kernel_instance_components) + for(const auto& component : kernel_instance_components) { EXPECT_THAT(kernel_string, ::testing::HasSubstr(component)); }