From 70f51bb03f7de3c2d980d99344cf425cff5703f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 10 May 2024 10:57:42 +0200 Subject: [PATCH] Change output gemm type to AccDataType in two stage conv bwd wei (#1283) [ROCm/composable_kernel commit: 8346af9c686649703904d3c8c5d81e89c4116d4c] --- ...onv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 25 ++++++++++++------- ..._conv_bwd_weight_xdl_bilinear_instance.hpp | 1 + ...t_grouped_conv_bwd_weight_xdl_bilinear.cpp | 2 ++ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index a5ae0565f3..3c33c7dbc1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -197,6 +197,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle K0PerBlock, ConvBackwardWeightSpecialization>{}; + static constexpr index_t MaxScalarPerVectorFP32 = 4; + static constexpr index_t WorkspaceInOutScalarPerVector = + is_same_v + ? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32) + : CBlockTransferScalarPerVector_NWaveNPerXdl; + // Bytes per 32 lds bank: 32 * 4 bytes static constexpr auto BankLength = 128; static constexpr auto ElePerBank = BankLength / sizeof(ADataType); @@ -297,7 +303,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ADataType, BDataType, AccDataType, - EDataType, + AccDataType, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, @@ -337,7 +343,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXdl, + WorkspaceInOutScalarPerVector, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, true, @@ -349,7 +355,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static constexpr auto MakeElementwiseInputSequence() { return generate_sequence_v2( - [&](auto) constexpr { return Number{}; }, + [&](auto) constexpr { return Number{}; }, Number{}); } @@ -499,7 +505,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {})); using CDGridDesc_M_N = decltype(concat_tuple(Tuple{}, DsGridDesc_M_N{})); using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); - using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); + using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); using EGridDesc_M_N = CGridDesc_M_N; static constexpr index_t ClusterLengthMPerBlock = CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); @@ -659,7 +665,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle std::size_t GetWorkspaceSizeBytes() const { - return sizeof(EDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; } const ADataType* p_a_grid_; @@ -738,7 +744,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { - EDataType* p_c_grid = type_convert(arg.p_workspace_); + AccDataType* p_c_grid = type_convert(arg.p_workspace_); const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; @@ -753,7 +759,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle GridwiseGemm, ADataType, BDataType, - EDataType, + AccDataType, OutElementwiseOperation, InElementwiseOperation, element_wise::PassThrough, @@ -786,7 +792,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle }; auto launch_elementwise_kernel = [&]() { - const EDataType* p_c_grid = type_convert(arg.p_workspace_); + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; @@ -907,7 +913,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle } // vector store C matrix into global memory - if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 && + arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0)) { return false; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp index dfd3216441..8b830d91d5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_bilinear_instance.hpp @@ -86,6 +86,7 @@ using device_grouped_conv_bwd_weight_xdl_c_shuffle_f16_bilinear_instances = std: //#########################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | // generic instance + DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 1, 4, true, 1, 1, S<1, 16, 1, 4>, 1>, DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, Tuple, F16, F16, F16, F32, Tuple, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, S<1, 4, 8, 2>, S<0, 3, 1, 2>, S<0, 2, 1, 3>, 2, 2, 4, true, 1, 1, S<1, 16, 1, 4>, 2>, // instance for small conv.K // for fp16 conv.K and conv.C must be divisible by 2 diff --git a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp b/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp index d733325a98..11748d4717 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_conv_bwd_weight_xdl_bilinear.cpp @@ -264,5 +264,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D) {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 4, 4, {3, 3, 3}, {14, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->Run(); }