Handle split-K autodeduction in explicit gemm conv.

This commit is contained in:
Ville Pietilä
2025-07-10 14:11:20 +00:00
parent 611d6ac82d
commit 08ebf5ecc2

View File

@@ -118,7 +118,8 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads},
p_wei_grid_{p_wei_grid}
p_wei_grid_{p_wei_grid},
split_k_{split_k}
{
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
@@ -176,7 +177,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
out_element_op,
in_element_op,
wei_element_op,
split_k};
split_k_};
}
else
{
@@ -199,7 +200,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
out_element_op,
in_element_op,
wei_element_op,
split_k};
split_k_};
}
}
@@ -236,6 +237,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
bool is_filter_data_packed;
CElementwiseGridDesc elementwise_desc_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
const ck::index_t split_k_;
};
// Invoker
@@ -301,6 +303,12 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
static bool IsSupportedArgument(const Argument& arg)
{
if (arg.split_k_ < 0)
{
// TODO: Add split-K autodeduction.
return false;
}
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())