mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Code clean-up.
This commit is contained in:
@@ -88,8 +88,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
//std::min(static_cast<index_t>(args.G_), GroupedConvTraitsType_::NumGroupsToMerge);
|
||||
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKXC
|
||||
@@ -103,14 +101,12 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
ZYX = conv_to_gemm_transformer.ZYX_;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
|
||||
<< ", ZYX: " << ZYX << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,8 +173,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
//std::min(static_cast<index_t>(args.G_), GroupedConvTraitsType_::NumGroupsToMerge);
|
||||
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
|
||||
@@ -192,14 +186,12 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
ZYX = conv_to_gemm_transformer.ZYX_;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
|
||||
<< ", ZYX: " << ZYX << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,8 +265,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
//std::min(static_cast<index_t>(args.G_), GroupedConvTraitsType_::NumGroupsToMerge);
|
||||
|
||||
group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
|
||||
group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
|
||||
group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
|
||||
@@ -288,14 +278,12 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch);
|
||||
ZYX = conv_to_gemm_transformer.ZYX_;
|
||||
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
|
||||
<< ", GemmBatch: " << GemmBatch
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch
|
||||
<< ", ZYX: " << ZYX << std::endl;
|
||||
<< ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,7 +310,6 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
index_t GemmK;
|
||||
index_t GemmBatch;
|
||||
index_t NumGroupsPerBatch;
|
||||
index_t ZYX;
|
||||
|
||||
const void* out_ptr;
|
||||
const void* in_ptr;
|
||||
|
||||
@@ -125,8 +125,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
|
||||
InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
|
||||
InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
|
||||
ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
|
||||
InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -164,8 +163,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I0]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{I0},
|
||||
InRightPadW_{input_right_pads[I0]},
|
||||
ZYX_{X_}
|
||||
InRightPadW_{input_right_pads[I0]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -219,8 +217,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I1]},
|
||||
InRightPadD_{I0},
|
||||
InRightPadH_{input_right_pads[I0]},
|
||||
InRightPadW_{input_right_pads[I1]},
|
||||
ZYX_{Y_ * X_}
|
||||
InRightPadW_{input_right_pads[I1]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -274,8 +271,7 @@ struct TransformConvBwdWeightToGemm
|
||||
InLeftPadW_{input_left_pads[I2]},
|
||||
InRightPadD_{input_right_pads[I0]},
|
||||
InRightPadH_{input_right_pads[I1]},
|
||||
InRightPadW_{input_right_pads[I2]},
|
||||
ZYX_{Z_ * Y_ * X_}
|
||||
InRightPadW_{input_right_pads[I2]}
|
||||
{
|
||||
static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
|
||||
@@ -413,9 +409,6 @@ struct TransformConvBwdWeightToGemm
|
||||
}
|
||||
#endif
|
||||
|
||||
//////////////////
|
||||
// 1D
|
||||
//////////////////
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 1, bool>::type = false>
|
||||
CK_TILE_HOST auto make_out_grid_desc() const
|
||||
{
|
||||
@@ -529,9 +522,6 @@ struct TransformConvBwdWeightToGemm
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////
|
||||
// 2D
|
||||
//////////////////
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 2, bool>::type = false>
|
||||
CK_TILE_HOST auto make_out_grid_desc() const
|
||||
{
|
||||
@@ -646,9 +636,6 @@ struct TransformConvBwdWeightToGemm
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////
|
||||
// 3D
|
||||
//////////////////
|
||||
template <index_t NDim = NDimSpatial, typename std::enable_if<NDim == 3, bool>::type = false>
|
||||
CK_TILE_HOST auto make_out_grid_desc() const
|
||||
{
|
||||
@@ -1075,7 +1062,6 @@ struct TransformConvBwdWeightToGemm
|
||||
IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
|
||||
IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
|
||||
IndexType InRightPadD_, InRightPadH_, InRightPadW_;
|
||||
IndexType ZYX_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user