diff --git a/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp b/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp index fcdb9e313b..abeaebe4ff 100644 --- a/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp +++ b/example/ck_builder/01_conv_2d_fwd/ckb_example_conv_fwd_2d.cpp @@ -2,15 +2,11 @@ #include "ck_tile/builder/conv_builder.hpp" #include "../utils/types.hpp" -int main() { - +int main() +{ namespace ckb = ck_tile::builder; namespace ckb_examples = ck_tile::builder::examples; - constexpr size_t m_tile = 128; - constexpr size_t n_tile = 128; - constexpr size_t k_tile = 32; - constexpr ckb_examples::ConvSignature FwdConvSignature { .spatial_dim = 2, @@ -20,10 +16,20 @@ int main() { }; static_assert(ckb::ValidConvSignature); + // To get valid configuration parameters, refer to "device_grouped_conv_fwd_xdl_comp_instance.hpp". + // This file contains the current instances of the kernel DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3. + // Currently the builder has this kernel hard-coded. + // In the future, we may need builders per kernel type since they typically have slightly different parameters. 
+ constexpr ckb::ThreadBlock FwdThreadBlock { .block_size = 256, - .submatrix = {.m = m_tile, .n = n_tile, .k = k_tile} + .submatrix = {.m = 256, .n = 256, .k = 32} // Tile sizes + }; + + constexpr ckb::ConvTuningParams FwdTuningParams + { + .ak1 = 8, .bk1 = 8, .m_per_xdl=32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4 }; constexpr ckb_examples::BlockTransfer FwdBlockTransfer @@ -33,35 +39,36 @@ int main() { .thread_cluster_dims_c = { .m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8}, .vector_transfer_a = { - .src_vector_dim = 2, .src_scaler_per_vector = 8, .dest_scaler_per_vector_k1 = 8, .add_extra = true}, + .src_vector_dim = 2, .src_scaler_per_vector = 2, .dest_scaler_per_vector_k1 = 8, .add_extra = false}, .vector_transfer_b = { - .src_vector_dim = 2, .src_scaler_per_vector = 8, .dest_scaler_per_vector_k1 = 8, .add_extra = true}, + .src_vector_dim = 2, .src_scaler_per_vector = 8, .dest_scaler_per_vector_k1 = 8, .add_extra = false}, .vector_transfer_c = { - .m_xdl_per_wave_per_shuffle = 1, .n_xdl_per_wave_per_shuffle = 2, .scaler_per_vector = 8}, + .m_xdl_per_wave_per_shuffle = 1, .n_xdl_per_wave_per_shuffle = 1, .scaler_per_vector = 8}, .a_thread_cluster_access_order = {1, 0, 2}, .b_thread_cluster_access_order = {1, 0, 2}, .a_source_access_order = {1, 0, 2}, .b_source_access_order = {1, 0, 2} }; - constexpr ckb_examples::ConvAlgorithm FwdConvAlgorithm { .thread_block = FwdThreadBlock, - .tuning_params = {.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 1, .n_xdl_per_wave = 4}, + .tuning_params = FwdTuningParams, .block_transfer = FwdBlockTransfer, - .pipeline_version = ckb::BlockGemmPipelineVersion::V1, + .pipeline_version = ckb::BlockGemmPipelineVersion::V4, }; using Builder = ckb::ConvBuilder; const auto kernel_string = Builder::Instance::TypeString(); - std::cout << "Generated kernel: " << kernel_string << std::endl; // The invoker is the entrypoint to launch the kernel. 
- // Creating the invoker triggers the validation of the builder configuration. - //auto invoker = Builder::Instance::MakeInvoker(); - //(void)invoker; + // Creating the invoker triggers the validation of the builder configuration, + // that is, the combination of all builder parameters is checked at compile time. + auto invoker = Builder::Instance::MakeInvoker(); + + // TODO: Prepare actual data and launch the kernel. + (void)invoker; return 0; } diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp index 053fe0fa3f..910bb16469 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp @@ -50,6 +50,8 @@ template concept ConvTuningDescriptor = requires(T t) { { t.ak1 } -> std::convertible_to; { t.bk1 } -> std::convertible_to; + { t.m_per_xdl } -> std::convertible_to; + { t.n_per_xdl } -> std::convertible_to; { t.m_xdl_per_wave } -> std::convertible_to; { t.n_xdl_per_wave } -> std::convertible_to; }; @@ -60,6 +62,8 @@ struct ConvTuningParams // NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!! int ak1 = 0; int bk1 = 0; + int m_per_xdl = 0; + int n_per_xdl = 0; int m_xdl_per_wave = 0; int n_xdl_per_wave = 0; };