mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Handle split-K autodeduction in explicit gemm conv.
This commit is contained in:
@@ -118,7 +118,8 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
p_wei_grid_{p_wei_grid}
|
||||
p_wei_grid_{p_wei_grid},
|
||||
split_k_{split_k}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
@@ -176,7 +177,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
split_k_};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -199,7 +200,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
split_k_};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,6 +237,7 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
bool is_filter_data_packed;
|
||||
CElementwiseGridDesc elementwise_desc_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
const ck::index_t split_k_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -301,6 +303,12 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if (arg.split_k_ < 0)
|
||||
{
|
||||
// TODO: Add split-K autodeduction.
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
|
||||
|
||||
Reference in New Issue
Block a user