mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Support broadcast for bias in grouped conv fwd (#1081)
* Support broadcast for bias in grouped conv fwd * Fix comment * Comment fixes * Remove GK layout
This commit is contained in:
@@ -357,15 +357,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
// Shape of Ds and E must be aligned. Strides can be different.
|
||||
// Pass e_g_n_k_wos_lengths for logical broadcast.
|
||||
static auto MakeDsGridDescriptor_M_N(
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
|
||||
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
|
||||
ds_g_n_k_wos_strides[i]);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
@@ -569,7 +571,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
// D desc
|
||||
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
|
||||
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
|
||||
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
|
||||
});
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
|
||||
@@ -916,8 +918,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
|
||||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
|
||||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
|
||||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
|
||||
is_same_v<DLayout, ctc::G_K>)
|
||||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
|
||||
{
|
||||
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
|
||||
|
||||
@@ -925,6 +926,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<DLayout, ctc::G_K>)
|
||||
{
|
||||
// G and K must be the same
|
||||
if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
|
||||
arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// E and D must have the same shape
|
||||
for(index_t d = 0; d < NDimSpatial + 3; d++)
|
||||
{
|
||||
if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
|
||||
{
|
||||
valid = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
|
||||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
|
||||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
|
||||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
|
||||
is_same_v<DLayout, ctc::G_K>)
|
||||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
|
||||
{
|
||||
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user