mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Support bf16/f32/f16 and NHWGC conv2d_bwd_data (#757)
* Support bf16/f32/f16 and NHWGC conv2d_bwd_data
* Add interface test
* clang format
* Comment fixes
* Add more friendly error message
[ROCm/composable_kernel commit: 63388e84ab]
This commit is contained in:
@@ -173,6 +173,10 @@
|
||||
|
||||
// workaround: compiler issue on gfx908
|
||||
#define CK_WORKAROUND_SWDEV_388832 1
|
||||
|
||||
// workaround: Grouped Conv2d_bwd_data fails for already implemented instance
|
||||
#define CK_WORKAROUND_SWDEV_3318619 0
|
||||
|
||||
// flag to enable (1) or disable (0) the debugging output in some kernels
|
||||
#define DEBUG_LOG 0
|
||||
|
||||
|
||||
@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_k_wos_lengths[0]},
|
||||
num_gemm_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
// number of GEMM
|
||||
num_gemm_ = YTilde * XTilde;
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
void Print() const
|
||||
{
|
||||
for(index_t i = 0; i < num_gemm_; i++)
|
||||
for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
std::cout << "a_grid_desc_ak0_m_ak1_container_"
|
||||
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
|
||||
@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// tensor descriptor for problem definition
|
||||
index_t num_group_;
|
||||
index_t num_gemm_;
|
||||
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
|
||||
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
|
||||
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
|
||||
@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i = 0; i < arg.num_gemm_; i++)
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
arg.b_grid_desc_n_k_container_[i],
|
||||
@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// vector load for A matrix from global memory to LDS
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// vector store for E
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC>)
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
|
||||
@@ -13,6 +13,61 @@
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
namespace {
|
||||
template <
|
||||
index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
|
||||
constexpr auto
|
||||
make_out_n_ho_wo_k_grid_desc(const index_t N,
|
||||
const index_t Ho,
|
||||
const index_t Wo,
|
||||
const index_t K,
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides)
|
||||
{
|
||||
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
|
||||
{
|
||||
const index_t NStride = out_g_n_k_wos_strides[1];
|
||||
const index_t HiStride = out_g_n_k_wos_strides[3];
|
||||
const index_t WiStride = out_g_n_k_wos_strides[4];
|
||||
const auto CStride = Number<1>{};
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
|
||||
make_tuple(WiStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
|
||||
{
|
||||
// assume packed
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <
|
||||
index_t NDimSpatial,
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
|
||||
@@ -29,11 +84,12 @@ struct TransformConvBwdDataToGemm_v1
|
||||
|
||||
template <typename ALayout,
|
||||
typename std::enable_if<NDimSpatial == 2 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNHWK>,
|
||||
(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>),
|
||||
bool>::type = false>
|
||||
static auto MakeADescriptor_AK0_M_AK1(
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
|
||||
@@ -70,9 +126,9 @@ struct TransformConvBwdDataToGemm_v1
|
||||
|
||||
const index_t AK0 = K / AK1;
|
||||
|
||||
// assume packed
|
||||
const auto out_n_ho_wo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
|
||||
make_out_n_ho_wo_k_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
|
||||
N, Ho, Wo, K, out_g_n_k_wos_strides);
|
||||
|
||||
if constexpr(ConvBwdDataSpecialization ==
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
@@ -80,7 +136,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_unmerge_transform(make_tuple(AK0, AK1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
|
||||
Reference in New Issue
Block a user