diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index a819b91b05..01c3276bf7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -118,7 +118,8 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, - p_wei_grid_{p_wei_grid} + p_wei_grid_{p_wei_grid}, + split_k_{split_k} { constexpr index_t spatial_offset = 3; const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset, @@ -176,7 +177,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl out_element_op, in_element_op, wei_element_op, - split_k}; + split_k_}; } else { @@ -199,7 +200,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl out_element_op, in_element_op, wei_element_op, - split_k}; + split_k_}; } } @@ -236,6 +237,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl bool is_filter_data_packed; CElementwiseGridDesc elementwise_desc_; Block2TileMapElementwise elementwise_block_2_ctile_map_; + const ck::index_t split_k_; }; // Invoker @@ -301,6 +303,12 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl static bool IsSupportedArgument(const Argument& arg) { + if (arg.split_k_ < 0) + { + // TODO: Add split-K autodeduction. + return false; + } + if constexpr(NDimSpatial == 2) { if constexpr(!is_NHWGC_GKYXC_NHWGK())