diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 9dabb3a273..abc474ef4b 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -88,8 +88,6 @@ struct GroupedConvBwdWeightKernelArgs c_grid_desc_m_n = grid_descs.at(number<2>{}); NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge; - //std::min(static_cast(args.G_), GroupedConvTraitsType_::NumGroupsToMerge); - group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC group_stride_c = args.K_ * args.C_ // C: Wei GKXC @@ -103,14 +101,12 @@ struct GroupedConvBwdWeightKernelArgs GemmN = b_grid_desc_n_k.get_length(number<0>{}); GemmK = a_grid_desc_m_k.get_length(number<1>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); - ZYX = conv_to_gemm_transformer.ZYX_; if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch - << ", ZYX: " << ZYX << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; } } @@ -177,8 +173,6 @@ struct GroupedConvBwdWeightKernelArgs c_grid_desc_m_n = grid_descs.at(number<2>{}); NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge; - //std::min(static_cast(args.G_), GroupedConvTraitsType_::NumGroupsToMerge); - group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC group_stride_c = args.K_ * args.C_ // C: Wei GKYXC @@ -192,14 +186,12 @@ struct GroupedConvBwdWeightKernelArgs GemmN = b_grid_desc_n_k.get_length(number<0>{}); GemmK = a_grid_desc_m_k.get_length(number<1>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); - ZYX = conv_to_gemm_transformer.ZYX_; if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch - << ", ZYX: " << ZYX << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; } } @@ -273,8 +265,6 @@ struct GroupedConvBwdWeightKernelArgs c_grid_desc_m_n = grid_descs.at(number<2>{}); NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge; - //std::min(static_cast(args.G_), GroupedConvTraitsType_::NumGroupsToMerge); - group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC @@ -288,14 +278,12 @@ struct GroupedConvBwdWeightKernelArgs GemmN = b_grid_desc_n_k.get_length(number<0>{}); GemmK = a_grid_desc_m_k.get_length(number<1>{}); GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); - ZYX = conv_to_gemm_transformer.ZYX_; if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK << ", GemmBatch: " << GemmBatch - << ", NumGroupsPerBatch: " << NumGroupsPerBatch - << ", ZYX: " << ZYX << std::endl; + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; } } @@ -322,7 +310,6 @@ struct GroupedConvBwdWeightKernelArgs index_t GemmK; index_t GemmBatch; index_t NumGroupsPerBatch; - index_t ZYX; const void* out_ptr; const void* in_ptr; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp index 2cc815d552..12e0979458 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -125,8 +125,7 @@ struct TransformConvBwdWeightToGemm InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, - InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, - ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)} { } @@ -164,8 +163,7 @@ struct TransformConvBwdWeightToGemm InLeftPadW_{input_left_pads[I0]}, InRightPadD_{I0}, InRightPadH_{I0}, - InRightPadW_{input_right_pads[I0]}, - ZYX_{X_} + InRightPadW_{input_right_pads[I0]} { static_assert(std::is_same_v> || std::is_same_v>); @@ -219,8 +217,7 @@ struct TransformConvBwdWeightToGemm InLeftPadW_{input_left_pads[I1]}, InRightPadD_{I0}, InRightPadH_{input_right_pads[I0]}, - InRightPadW_{input_right_pads[I1]}, - ZYX_{Y_ * X_} + InRightPadW_{input_right_pads[I1]} { static_assert(std::is_same_v> || std::is_same_v>); @@ -274,8 +271,7 @@ struct TransformConvBwdWeightToGemm InLeftPadW_{input_left_pads[I2]}, InRightPadD_{input_right_pads[I0]}, InRightPadH_{input_right_pads[I1]}, - InRightPadW_{input_right_pads[I2]}, - ZYX_{Z_ * Y_ * X_} + InRightPadW_{input_right_pads[I2]} { static_assert(std::is_same_v> || std::is_same_v>); @@ -413,9 +409,6 @@ struct TransformConvBwdWeightToGemm } #endif - ////////////////// - // 1D - ////////////////// template ::type = false> CK_TILE_HOST auto make_out_grid_desc() const { @@ -529,9 +522,6 @@ struct TransformConvBwdWeightToGemm } } - ////////////////// - // 2D - ////////////////// template ::type = false> CK_TILE_HOST auto make_out_grid_desc() const { @@ -646,9 +636,6 @@ struct TransformConvBwdWeightToGemm } } - ////////////////// - // 3D - ////////////////// template ::type = false> CK_TILE_HOST auto make_out_grid_desc() const { @@ -1075,7 +1062,6 @@ struct TransformConvBwdWeightToGemm IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; IndexType InRightPadD_, InRightPadH_, InRightPadW_; - IndexType ZYX_; }; } // namespace ck_tile