mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Conv bwd splitN support (#3047)
* Conv bwd splitN support * Adjust splitting calculations to lengths format * Prepare indexing for future splitK support
This commit is contained in:
@@ -27,7 +27,7 @@ struct TransformConvBwdDataToGemm
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr auto I4 = number<4>{};
|
||||
static constexpr auto I5 = number<5>{};
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
|
||||
const ConvDimsType& strides,
|
||||
@@ -44,25 +44,45 @@ struct TransformConvBwdDataToGemm
|
||||
}
|
||||
|
||||
template <typename ConvDimsType>
|
||||
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_strides,
|
||||
const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& c_g_n_k_wos_strides)
|
||||
static IndexType GetSplitedNSize(const ConvDimsType& c_g_n_k_wos_lengths,
|
||||
const ConvDimsType& a_g_n_c_wis_lengths)
|
||||
{
|
||||
|
||||
// Calculate strides internally assuming contiguous memory layout
|
||||
ConvDimsType c_g_n_k_wos_strides, a_g_n_c_wis_strides;
|
||||
const index_t num_dims = c_g_n_k_wos_strides.size();
|
||||
|
||||
// Calculate strides for input tensor (innermost to outermost),
|
||||
// Don't include outermost dimension G since it's gemm batch.
|
||||
a_g_n_c_wis_strides[num_dims - 1] = 1;
|
||||
for(index_t i = num_dims - 2; i >= 1; i--)
|
||||
{
|
||||
a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
|
||||
}
|
||||
|
||||
// Calculate strides for output tensor,
|
||||
// Don't include outermost dimension G since it's gemm batch.
|
||||
c_g_n_k_wos_strides[num_dims - 1] = 1;
|
||||
for(index_t i = num_dims - 2; i >= 1; i--)
|
||||
{
|
||||
c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
|
||||
}
|
||||
|
||||
const long_index_t a_element_space_size =
|
||||
calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
|
||||
const long_index_t c_element_space_size =
|
||||
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
|
||||
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
|
||||
c_element_space_size * sizeof(CDataType));
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
const long_index_t element_space_size = ck_tile::max(
|
||||
a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
|
||||
|
||||
const IndexType N = a_g_n_c_wis_lengths[I1];
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
|
||||
const IndexType N = c_g_n_k_wos_lengths[I1];
|
||||
|
||||
if(element_space_size > TwoGB)
|
||||
{
|
||||
// Minimum divisor of N to not exceed 2GB
|
||||
const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
|
||||
const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);
|
||||
|
||||
if(divisor <= static_cast<double>(N))
|
||||
{
|
||||
@@ -93,9 +113,12 @@ struct TransformConvBwdDataToGemm
|
||||
return N;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
public:
|
||||
// Public getter methods for Split-N support
|
||||
CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
|
||||
CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }
|
||||
|
||||
CK_TILE_HOST constexpr TransformConvBwdDataToGemm() {}
|
||||
|
||||
template <typename TransformConvBwdDataToGemmBase>
|
||||
@@ -103,6 +126,7 @@ struct TransformConvBwdDataToGemm
|
||||
TransformConvBwdDataToGemm(const TransformConvBwdDataToGemmBase& transform_conv_to_gemm_base)
|
||||
: G_{static_cast<IndexType>(transform_conv_to_gemm_base.G_)},
|
||||
N_{static_cast<IndexType>(transform_conv_to_gemm_base.N_)},
|
||||
original_N_{static_cast<IndexType>(transform_conv_to_gemm_base.original_N_)},
|
||||
Di_{static_cast<IndexType>(transform_conv_to_gemm_base.Di_)},
|
||||
Hi_{static_cast<IndexType>(transform_conv_to_gemm_base.Hi_)},
|
||||
Wi_{static_cast<IndexType>(transform_conv_to_gemm_base.Wi_)},
|
||||
@@ -170,17 +194,18 @@ struct TransformConvBwdDataToGemm
|
||||
IdxYTilde_{I1},
|
||||
IdxXTilde_{tildes[I0]}
|
||||
{
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
|
||||
// Store original N
|
||||
original_N_ = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
N_ = a_g_n_c_wis_lengths[I1];
|
||||
}
|
||||
#endif
|
||||
|
||||
GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
|
||||
XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
|
||||
@@ -229,17 +254,19 @@ struct TransformConvBwdDataToGemm
|
||||
IdxYTilde_{tildes[I0]},
|
||||
IdxXTilde_{tildes[I1]}
|
||||
{
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
|
||||
// Store original N
|
||||
original_N_ = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
N_ = a_g_n_c_wis_lengths[I1];
|
||||
}
|
||||
#endif
|
||||
|
||||
GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
|
||||
GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_);
|
||||
XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
|
||||
@@ -291,17 +318,19 @@ struct TransformConvBwdDataToGemm
|
||||
IdxYTilde_{tildes[I1]},
|
||||
IdxXTilde_{tildes[I2]}
|
||||
{
|
||||
#if 0 // TODO: Enable these functionalities
|
||||
|
||||
// Store original N
|
||||
original_N_ = a_g_n_c_wis_lengths[I1];
|
||||
|
||||
if constexpr(SplitN)
|
||||
{
|
||||
N_ = GetSplitedNSize(
|
||||
a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
|
||||
N_ = GetSplitedNSize(c_g_n_k_wos_lengths, a_g_n_c_wis_lengths);
|
||||
}
|
||||
else
|
||||
{
|
||||
N_ = c_g_n_k_wos_lengths[I1];
|
||||
N_ = a_g_n_c_wis_lengths[I1];
|
||||
}
|
||||
#endif
|
||||
|
||||
GcdStrideDilationW_ = gcd(ConvStrideW_, ConvDilationW_);
|
||||
GcdStrideDilationH_ = gcd(ConvStrideH_, ConvDilationH_);
|
||||
GcdStrideDilationD_ = gcd(ConvStrideD_, ConvDilationD_);
|
||||
@@ -1068,7 +1097,7 @@ struct TransformConvBwdDataToGemm
|
||||
in_gemmmraw_gemmnraw_grid_desc);
|
||||
}
|
||||
|
||||
IndexType G_, N_;
|
||||
IndexType G_, N_, original_N_;
|
||||
IndexType Di_, Hi_, Wi_;
|
||||
IndexType Do_, Ho_, Wo_;
|
||||
IndexType Z_, Y_, X_;
|
||||
|
||||
Reference in New Issue
Block a user