mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
conv+conv (1x1 only) example using gemm+gemm (#393)
* refactor conv
* add conv+conv example, 1x1 only
[ROCm/composable_kernel commit: 4df6d93f60]
This commit is contained in:
@@ -49,30 +49,47 @@ struct ConvParam
|
||||
|
||||
std::size_t GetFlops() const;
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
std::size_t GetByte() const
|
||||
template <typename InDataType>
|
||||
std::size_t GetInputByte() const
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(InDataType) *
|
||||
(G_ * N_ * C_ *
|
||||
std::accumulate(std::begin(input_spatial_lengths_),
|
||||
std::begin(input_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(WeiDataType) *
|
||||
(G_ * K_ * C_ *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(OutDataType) * (G_ * N_ * K_ *
|
||||
(G_ * N_ * C_ *
|
||||
std::accumulate(std::begin(input_spatial_lengths_),
|
||||
std::begin(input_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename WeiDataType>
|
||||
std::size_t GetWeightByte() const
|
||||
{
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
return sizeof(WeiDataType) *
|
||||
(G_ * K_ * C_ *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename OutDataType>
|
||||
std::size_t GetOutputByte() const
|
||||
{
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(OutDataType) * (G_ * N_ * K_ *
|
||||
std::accumulate(std::begin(output_spatial_lengths_),
|
||||
std::end(output_spatial_lengths_),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
std::size_t GetByte() const
|
||||
{
|
||||
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
|
||||
GetOutputByte<OutDataType>();
|
||||
}
|
||||
};
|
||||
|
||||
std::string get_conv_param_parser_helper_msg();
|
||||
|
||||
Reference in New Issue
Block a user