From e17ac63e4a417ad420df4be66519efe8a9c30a16 Mon Sep 17 00:00:00 2001 From: Mateusz Ozga Date: Tue, 13 May 2025 10:50:25 +0000 Subject: [PATCH] Fix CI --- .../unary_element_wise_operation.hpp | 8 +- .../ops/epilogue/cshuffle_epilogue.hpp | 16 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 143 +++++++----------- .../test_multiple_d_gemm_ut_cases.inc | 6 +- .../test_multiple_d_gemm_util.hpp | 34 ++++- .../test_grouped_convnd_bwd_weight.cpp | 75 +++++---- 6 files changed, 137 insertions(+), 145 deletions(-) diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 4d34d8c0bf..a8046a09ac 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1496,8 +1496,10 @@ struct ElementWiseAdd * @note [return] Perform element-wise addition and store the result in 'r' */ template - CK_TILE_DEVICE auto operator()(ResT& r, const ParamT& a, const ParamT& b, const ParamT& c) const - -> void + CK_TILE_DEVICE auto operator()(ResT& r, + [[maybe_unused]] const ParamT& a, + [[maybe_unused]] const ParamT& b, + [[maybe_unused]] const ParamT& c) const -> void { r = a + b + c; } @@ -1536,7 +1538,7 @@ struct ElementWiseMul CK_TILE_DEVICE auto operator()(ResT& r, const ParamT& a, const ParamT& b, const ParamT& c) const -> void { - r = a + b + c; + r = a * b * c; } /** diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 2b1e811412..38ec2c6996 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -154,7 +154,7 @@ struct CShuffleEpilogue template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, - onst DsDramWindows& ds_dram_window, + const DsDramWindows& ds_dram_window, void* p_smem) { @@ -190,10 +190,6 @@ struct CShuffleEpilogue [&](auto idx) { return make_tile_window(ds_dram_window[idx], dram_tile_distribution); }, number{}); - using elemenet_wise_output_t = - decltype(load_tile(make_tile_window(out_lds_window, dram_tile_distribution))); - elemenet_wise_output_t elemenet_wise_output; - constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; @@ -215,26 +211,26 @@ struct CShuffleEpilogue store_tile(in_lds_window, c_warp_in_tensor_casted); block_sync_lds(); - const auto c_out_tensor = + auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); const auto ds_tensor = generate_tuple( [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); const auto c_ds_tiles = concat_tuple_of_reference( - tie(elemenet_wise_output, c_out_tensor), + tie(c_out_tensor, c_out_tensor), generate_tie( - [&](auto i) -> const auto& { return ds_tensor[i]; }, number{})); + [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); tile_elementwise_in_out_unpack_tuple(typename Problem::CDElementwise{}, c_ds_tiles); if constexpr(MemoryOperation == memory_operation_enum::set) { - store_tile(out_dram_window, c_out_tensor); + store_tile(out_dram_window, elemenet_wise_output); } else { - update_tile(out_dram_window, c_out_tensor); + update_tile(out_dram_window, elemenet_wise_output); } if constexpr(iAccess != num_access - 1) { diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 51edfc9b8c..e31547d484 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -9,16 +9,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" -#include "ck_tile/core/utility/env.hpp" namespace ck_tile { -/// @brief The GEMM kernel host arguments. -/// -/// @par Overview -/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments -/// object. It contain all necessary information required to build proper kernel argument -/// and launch kernel on GPU. template struct GemmHostArgs { @@ -64,75 +57,23 @@ struct GemmHostArgs index_t k_batch; }; -/// @brief The GEMM kernel device arguments. template > struct GemmKernelArgs { - /// @brief The A input tensor's pointer to device memory. const void* a_ptr; - /// @brief The B input tensor's pointer to device memory. const void* b_ptr; - /// @brief The Ds input tensor's tuple to device memory. const DType ds_ptr; - /// @brief The C output tensor's pointer to device memory. void* c_ptr; - /// @brief GEMM's M dimension size. index_t M; - /// @brief GEMM's N dimension size. index_t N; - /// @brief GEMM's K dimension size. index_t K; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of A tensor. index_t stride_A; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of B tensor. index_t stride_B; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of Ds tensor. const index_t* stride_Ds; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of C tensor. index_t stride_C; index_t k_batch; }; -/// @brief The GEMM kernel template. -/// -/// @paragraph Overview Overview -/// This class provides the generic matrix multiplication kernel template. By semantic -/// division of GEMM algorithm into following parts we achieve flexible, versatile -/// and robust kernel implementation. -/// -/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() -/// function call operator" which determines the work scope of each workgroup. -/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. -/// This is the place where each workgroup is loading data from global memory and -/// carrying out dot products. -/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation -/// responsible for storing results to global memory. This is also the place where -/// any additional operator fusion may take place. -/// -/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ -/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all -/// internal details of those functional parts. You can think of it like both gemm and -/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover -/// the policy is responsible for definition of all necessary data layouts and thread's -/// work distribution. -/// -/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the -/// output data tile to be calculated. It determines the workgroup to -/// data relationship (or in other words - which data would be -/// processed and calculated by which workgroup). -/// @tparam GemmPipeline_ The type of class which provides the core part of matrix -/// multiplication. This class should provide implementation of data -/// loading from global memory and performing block-wise matrix -/// multiplication. You can think of it as a work done by single -/// workgroup point of view. -/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix -/// multiplication implementation. It is responsible for storing -/// results calculated by @ref GemmPipeline_ "GemmPipeline" to -/// the output C tensor in global memory. template struct GemmKernel { @@ -580,10 +521,9 @@ struct GemmKernel } }(); - return make_tuple(a_tensor_view, - b_tensor_view, - generate_tuple(d_tensor_view, number{}), - c_tensor_view); + const auto& ds_tensor_view = generate_tuple(d_tensor_view, number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); } template @@ -740,7 +680,9 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * + * @tparam DstInMemOp Destination memory operation (default: set). */ + template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, const DsGridPointer ds_ptr, @@ -752,10 +694,8 @@ struct GemmKernel const index_t block_idx_n) { // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr,c_ptr, kargs, splitk_batch_offset); - + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -774,9 +714,12 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } /** @@ -795,7 +738,9 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * + * @tparam DstInMemOp Destination memory operation (default: set). */ + template CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, const DsGridPointer ds_ptr, @@ -808,10 +753,8 @@ struct GemmKernel const index_t block_idx_n) { // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset); - + const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -829,10 +772,12 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } CK_TILE_DEVICE void operator()(GemmKernelArgs& kargs) const @@ -849,7 +794,6 @@ struct GemmKernel static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - // const DsGridPointer* ds_ptr = reinterpret_cast(kargs.ds_ptr); CDataType* c_ptr = static_cast(kargs.c_ptr); @@ -859,9 +803,7 @@ struct GemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + if(kargs.k_batch == 1) { RunGemm2LDS(a_ptr, b_ptr, @@ -874,12 +816,27 @@ struct GemmKernel i_m, i_n); } + else + { + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) + if(kargs.k_batch == 1) { RunGemm(a_ptr, b_ptr, @@ -891,6 +848,22 @@ struct GemmKernel i_m, i_n); } + else + { + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } } } }; diff --git a/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_ut_cases.inc b/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_ut_cases.inc index efb69b3be3..fb3263a585 100644 --- a/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_ut_cases.inc +++ b/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_ut_cases.inc @@ -2,8 +2,8 @@ TYPED_TEST(TestCkTileMultipleDGemm, Basic) { - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 512; + constexpr int M = 3840; + constexpr int N = 4096; + constexpr int K = 4096; this->Run(M, N, K); } diff --git a/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp b/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp index 9ce4a739d9..814cc98b4c 100644 --- a/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp +++ b/test/ck_tile/multiple_d_gemm/test_multiple_d_gemm_util.hpp @@ -120,10 +120,13 @@ class TestCkTileMultipleDGemm : public ::testing::Test float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -180,11 +184,29 @@ class TestCkTileMultipleDGemm : public ::testing::Test return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index 92628bfd1d..27553ce6fa 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -175,52 +175,51 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d); TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D) { - // this->conv_params.clear(); - // this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); - // this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); - // this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); - // this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); - // this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); - // this->conv_params.push_back({1, 1, 1, 1, 1, {3}, {32}, {1}, {1}, {1}, {1}}); - // this->Run(); + this->conv_params.clear(); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 1, 1, 1, {3}, {32}, {1}, {1}, {1}, {1}}); + this->Run(); } TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D) { this->conv_params.clear(); + this->conv_params.push_back( + {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - // this->conv_params.push_back( - // {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); - // this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, - // 0}}); this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, - // {0, 0}}); this->conv_params.push_back( - // {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - // this->conv_params.push_back( - // {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - // this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, - // 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, - // {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, - // 1}, {1, 1}}); this->conv_params.push_back( - // {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}}); this->Run(); } TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) { - // this->conv_params.clear(); - // this->conv_params.push_back( - // {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - // this->conv_params.push_back( - // {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - // this->conv_params.push_back( - // {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - // this->conv_params.push_back( - // {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - // this->conv_params.push_back( - // {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - // this->conv_params.push_back( - // {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - // this->conv_params.push_back( - // {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - // this->Run(); -} + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->Run(); +} \ No newline at end of file