diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 6af8ac6488..5cc0c70b51 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -91,8 +91,8 @@ inline bool parse_cmd_args(int argc, config.time_kernel = std::stoi(argv[3]); const ck::index_t num_dim_spatial = std::stoi(argv[4]); - conv_params = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + conv_params = + ck::utils::conv::parse_conv_param(num_dim_spatial, num_conv_param_leading_args, argv); } else { diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp index b1554412b1..8d9606822e 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp @@ -11,10 +11,10 @@ using CShuffleDataType = FP16; using DsDataType = ck::Tuple<>; using InDataType = FP16; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using DsLayout = ck::Tuple<>; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using OutElementOp = PassThrough; using WeiElementOp = PassThrough; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc index 25678491ce..2d465922db 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc @@ -90,6 +90,9 @@ bool run_conv_bwd_data(const ExecutionConfig& config, wei_element_op, in_element_op); + DeviceMem workspace_buf(argument.GetWorkspaceSizeBytes()); + conv.SetWorkspacePointer(&argument, workspace_buf.GetDeviceBuffer()); + if(!conv.IsSupportedArgument(argument)) { std::cerr << "wrong! device_conv with the specified compilation parameters does "