diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc index 637ea2fbfb..6792e70ebf 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc @@ -130,7 +130,7 @@ int run_grouped_conv_bwd_weight_example_with_layouts( output_dev_buf.GetDeviceBuffer(), kbatch); - std::cout << "Run Grouped Conv Fwd kernel" << std::endl; + std::cout << "Run Grouped Conv Bwd Weight kernel" << std::endl; std::cout << "input: " << input.mDesc << std::endl; std::cout << "weight: " << weight.mDesc << std::endl; std::cout << "output: " << output.mDesc << std::endl; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index f4cc5d7138..75b1a5c11a 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -384,9 +384,8 @@ struct GroupedConvBwdDataKernelArgs static constexpr index_t MaxGroupedGemmGroupsNum = 128; - using ABCGridDescs = - remove_cvref_t; + using ABCGridDescs = remove_cvref_t< + decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>; using AGridDescMK = remove_cvref_t{}])>; using BGridDescNK = remove_cvref_t{}])>; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 2700353049..38276ee4e1 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -254,9 +254,8 @@ struct GroupedConvBwdWeightKernelArgs GemmBatch = args.G_; } - using ABCGridDescs = - remove_cvref_t; + using ABCGridDescs = remove_cvref_t< + decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>; using AGridDescMK = remove_cvref_t{}])>; using BGridDescNK = remove_cvref_t{}])>;