diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index f4cc5d7138..75b1a5c11a 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -384,9 +384,8 @@ struct GroupedConvBwdDataKernelArgs static constexpr index_t MaxGroupedGemmGroupsNum = 128; - using ABCGridDescs = - remove_cvref_t; + using ABCGridDescs = remove_cvref_t< + decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>; using AGridDescMK = remove_cvref_t{}])>; using BGridDescNK = remove_cvref_t{}])>; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 2700353049..38276ee4e1 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -254,9 +254,8 @@ struct GroupedConvBwdWeightKernelArgs GemmBatch = args.G_; } - using ABCGridDescs = - remove_cvref_t; + using ABCGridDescs = remove_cvref_t< + decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>; using AGridDescMK = remove_cvref_t{}])>; using BGridDescNK = remove_cvref_t{}])>;