mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Support bf16/f32/f16 and NHWGC conv2d_bwd_data (#757)
* Support bf16/f32/f16 and NHWGC conv2d_bwd_data
* Add interface test
* clang format
* Comment fixes
* Add more friendly error message
This commit is contained in:
@@ -459,7 +459,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_k_wos_lengths[0]},
|
||||
num_gemm_{},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
@@ -508,9 +507,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
// number of GEMM
|
||||
num_gemm_ = YTilde * XTilde;
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
@@ -626,7 +622,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
void Print() const
|
||||
{
|
||||
for(index_t i = 0; i < num_gemm_; i++)
|
||||
for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
std::cout << "a_grid_desc_ak0_m_ak1_container_"
|
||||
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
|
||||
@@ -654,7 +650,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// tensor descriptor for problem definition
|
||||
index_t num_group_;
|
||||
index_t num_gemm_;
|
||||
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
|
||||
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
|
||||
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
|
||||
@@ -708,7 +703,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
for(index_t i = 0; i < arg.num_gemm_; i++)
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
arg.b_grid_desc_n_k_container_[i],
|
||||
@@ -807,7 +802,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// vector load for A matrix from global memory to LDS
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
@@ -862,7 +858,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// vector store for E
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC>)
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NHWGC>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
|
||||
Reference in New Issue
Block a user