diff --git a/example/ck_tile/19_grouped_convolution/CMakeLists.txt b/example/ck_tile/20_grouped_convolution/CMakeLists.txt similarity index 100% rename from example/ck_tile/19_grouped_convolution/CMakeLists.txt rename to example/ck_tile/20_grouped_convolution/CMakeLists.txt diff --git a/example/ck_tile/19_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp similarity index 94% rename from example/ck_tile/19_grouped_convolution/grouped_convolution_forward.cpp rename to example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index ee72bb628c..974e0d709b 100644 --- a/example/ck_tile/19_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -19,7 +19,10 @@ template + typename OutLayout, + typename DsDataType = ck_tile::tuple<>, + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s) { constexpr int kBlockPerCu = 1; @@ -49,7 +52,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile:: constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using TilePartitioner = ck_tile::GemmTile1DPartitioner; using GroupedConvTraitsType = - ck_tile::GroupedConvTraits; + ck_tile::GroupedConvTraits; using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; template < typename InLay = typename GroupedConvTraitsType::InLayout, @@ -61,6 +62,10 @@ struct GroupedConvFwdKernelArgs in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } out_ptr = args.out_ptr; ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, @@ -133,6 +138,10 @@ struct GroupedConvFwdKernelArgs in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } out_ptr = args.out_ptr; ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, @@ -214,6 +223,10 @@ struct GroupedConvFwdKernelArgs in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } out_ptr = args.out_ptr; ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, @@ -270,6 +283,7 @@ struct GroupedConvFwdKernelArgs const void* in_ptr; const void* wei_ptr; + std::array ds_ptr; void* out_ptr; AGridDescMK a_grid_desc_m_k; @@ -338,11 +352,17 @@ struct GroupedConvolutionForwardKernel using InLayout = remove_cvref_t; using WeiLayout = remove_cvref_t; using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + using GemmDsLayout = remove_cvref_t; + + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; using InDataType = remove_cvref_t; using WeiDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. using OutDataType = remove_cvref_t; @@ -517,10 +537,12 @@ struct GroupedConvolutionForwardKernel } template - CK_TILE_DEVICE static auto MakeGemmTensorViews(const InDataType* a_ptr, - const WeiDataType* b_ptr, - OutDataType* c_ptr, - const GroupedConvFwdKernelArgsSpecialized& kargs) + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const InDataType* a_ptr, + const WeiDataType* b_ptr, + const std::array& ds_ptr, + OutDataType* c_ptr, + const GroupedConvFwdKernelArgsSpecialized& kargs) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); @@ -537,7 +559,21 @@ struct GroupedConvolutionForwardKernel return make_tensor_view(c_ptr, kargs.c_grid_desc_m_n); }(); - return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view); + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + static_assert(std::is_same_v, OutLayout>, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported!"); + static_assert(std::is_same_v, OutDataType>, + "Not supported!"); + + return make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_desc_m_n); + }, + number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); } template @@ -559,24 +595,35 @@ struct GroupedConvolutionForwardKernel sequence{}); }(); + const auto& ds_tensor_view = views.at(I2); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I2); + const auto& c_tensor_view = views.at(I3); return pad_tensor_view(c_tensor_view, make_tuple(number{}, number{}), sequence{}); }(); - return make_tuple(a_pad_view, b_pad_view, c_pad_view); + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); } template CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& c_pad_view = views.at(I2); + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); const auto& a_block_window = [&]() { return make_tile_window(a_pad_view, @@ -592,12 +639,21 @@ struct GroupedConvolutionForwardKernel {i_n, 0}); }(); + const auto ds_block_window = generate_tuple( + [&](auto i) { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + auto c_block_window = make_tile_window( c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - return make_tuple(a_block_window, b_block_window, c_block_window); + return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); } /** @@ -614,6 +670,7 @@ struct GroupedConvolutionForwardKernel */ CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr, const WeiDataType* b_ptr, + const std::array& ds_ptr, OutDataType* c_ptr, void* smem_ptr_0, const GroupedConvFwdKernelArgsSpecialized& kargs, @@ -622,7 +679,8 @@ struct GroupedConvolutionForwardKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs); + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -633,15 +691,16 @@ struct GroupedConvolutionForwardKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + auto& c_block_window = gemm_tile_windows.at(I3); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr_0); + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } /** @@ -661,6 +720,7 @@ struct GroupedConvolutionForwardKernel */ CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr, const WeiDataType* b_ptr, + const std::array& ds_ptr, OutDataType* c_ptr, void* __restrict__ smem_ptr_0, void* __restrict__ smem_ptr_1, @@ -670,7 +730,8 @@ struct GroupedConvolutionForwardKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs); + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -680,15 +741,16 @@ struct GroupedConvolutionForwardKernel // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + auto& c_block_window = gemm_tile_windows.at(I3); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr_0); + c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1); } CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const @@ -719,7 +781,8 @@ struct GroupedConvolutionForwardKernel EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { - RunGemm2LDS(a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n); + RunGemm2LDS( + a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n); } } else @@ -728,7 +791,7 @@ struct GroupedConvolutionForwardKernel EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n); + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n); } } } diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index fad7c15880..4b7cb3c895 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -20,11 +20,13 @@ struct GroupedConvHostArgs : public conv::ConvParam CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, const void* in_ptr_, const void* wei_ptr_, + const std::vector ds_ptr_, void* out_ptr_, index_t k_batch_) : conv::ConvParam(conv_param), in_ptr(in_ptr_), wei_ptr(wei_ptr_), + ds_ptr(ds_ptr_), out_ptr(out_ptr_), k_batch(k_batch_) { @@ -32,6 +34,7 @@ struct GroupedConvHostArgs : public conv::ConvParam const void* in_ptr; const void* wei_ptr; + const std::vector ds_ptr; void* out_ptr; index_t k_batch; }; @@ -40,13 +43,23 @@ template struct GroupedConvTraits { + private: + static constexpr auto generate_implicit_gemm_layout() + { + return generate_tuple([](auto) { return ck_tile::tensor_layout::gemm::RowMajor{}; }, + number{}); + } + + public: static constexpr index_t NDimSpatial = NDimSpatial_; static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_; using InLayout = InLayout_; using WeiLayout = WeiLayout_; + using DsLayout = DsLayout_; using OutLayout = OutLayout_; using GroupedConvImplicitGemmTraits = TileGemmTraits; + static constexpr index_t NumDTensor = DsLayout::size(); + using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); }; } // namespace ck_tile