From d132df2bf5929686ec1335f4e9504e78551a1fa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Tue, 3 Feb 2026 04:10:09 -0500 Subject: [PATCH] Finalize conv specialization for filter 3x3, pad 1, stride 1, dilation 1 case. --- .../run_grouped_conv_fwd_example.inc | 4 +- .../transform_conv_fwd_to_gemm.hpp | 56 +++++++++++++++++-- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index 4cfd23ea66..5c7be8a2a4 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -76,7 +76,7 @@ using DeviceConvFwdInstance = InElementOp, WeiElementOp, OutElementOp, - //ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, + //ConvSpec, // ConvForwardSpecialization ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1_32_4_4_200x200, GemmSpec, // GemmSpecialization 256, // BlockSize @@ -108,7 +108,7 @@ using DeviceConvFwdInstance = S<1, 32, 1, 4>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, must match with num merged groups 1, // Vector load/store size for output tensor = CDEBlockTransferScalarPerVector_NPerBlock ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, + ck::BlockGemmPipelineVersion::v1, InKernelDataType, WeiKernelDataType, false, // No direct load diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 7202833737..6730d3a01b 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -882,12 +882,12 @@ struct TransformConvFwdToGemm throw std::runtime_error(oss.str()); } - const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + constexpr auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( make_tuple(N, H, W, NumGroupsToMerge, C), make_tuple( - NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + NStride, HiStride, WiStride, GStride, CStride)); - const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + constexpr auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( in_n_hi_wi_groups_c_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(H, Pad1, Pad1), @@ -899,7 +899,7 @@ struct TransformConvFwdToGemm make_tuple( Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + constexpr auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( in_n_hip_wip_groups_c_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Y, H), @@ -916,12 +916,14 @@ struct TransformConvFwdToGemm Sequence<5>{}, Sequence<6>{})); - return transform_tensor_descriptor( + constexpr auto gemmm_gemmn_desc = transform_tensor_descriptor( in_n_y_ho_x_wo_groups_c_desc, make_tuple(make_merge_transform(make_tuple(N, H, W, NumGroupsToMerge)), make_merge_transform(make_tuple(Y, X, C))), make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); + + return gemmm_gemmn_desc; } else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0) @@ -1477,6 +1479,50 @@ struct TransformConvFwdToGemm make_tuple(Sequence<0>{}, Sequence<1>{})); } } + else if constexpr (ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3Stride1Pad1Dilation1_32_4_4_200x200 && + NumGroupsToMerge > 1) + { + constexpr ck::index_t C = 4; + constexpr ck::index_t K = 4; + + using FilterSizeNumType = + ck::conditional_t, + ck::conditional_t, Number<27*C>>>; + + constexpr ck::index_t KStrideTensorB = FilterSizeNumType{}; + constexpr ck::index_t GStrideTensorB = Number{}; + constexpr ck::index_t CStrideTensorB = 1; + + // Ensure the strides match with the expected values + if (KStrideTensorB != this->KStrideTensorB_ || + GStrideTensorB != this->GStrideTensorB_ || + CStrideTensorB != this->CStrideTensorB_) + { + std::stringstream oss; + oss << "Error: Stride mismatch in MakeBDescriptor_N_K for special case. " + << "Expected KStrideTensorB: " << KStrideTensorB + << ", Actual: " << this->KStrideTensorB_ << ". " + << "Expected GStrideTensorB: " << GStrideTensorB + << ", Actual: " << this->GStrideTensorB_ << ". " + << "Expected CStrideTensorB: " << CStrideTensorB + << ", Actual: " << this->CStrideTensorB_ << "."; + throw std::runtime_error(oss.str()); + } + + constexpr auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStrideTensorB, GStrideTensorB, CStrideTensorB)); + constexpr auto gemmn_gemm_k_desc = transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)), + make_pass_through_transform(FilterSizeNumType{})), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return gemmn_gemm_k_desc; + } else { if constexpr(NumGroupsToMerge == 1)