mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
Simplify static_cast if-lands (#1828)
This commit is contained in:
@@ -28,14 +28,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
|
||||
std::is_same_v<ComputeDataType, BF16> ||
|
||||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
|
||||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
|
||||
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
|
||||
|
||||
double compute_error = 0;
|
||||
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
|
||||
std::is_same_v<ComputeDataType, int>)
|
||||
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -44,14 +41,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
|
||||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
|
||||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>,
|
||||
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the relative threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>)
|
||||
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -61,14 +55,11 @@ double get_relative_threshold(const int number_of_accumulations = 1)
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
|
||||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
|
||||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>,
|
||||
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the relative threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>)
|
||||
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -89,15 +80,12 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
static_assert(std::is_same_v<ComputeDataType, F8> || std::is_same_v<ComputeDataType, F16> ||
|
||||
std::is_same_v<ComputeDataType, BF16> ||
|
||||
std::is_same_v<ComputeDataType, F32> || std::is_same_v<ComputeDataType, I8> ||
|
||||
std::is_same_v<ComputeDataType, I32> || std::is_same_v<ComputeDataType, int>,
|
||||
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
|
||||
|
||||
auto expo = std::log2(std::abs(max_possible_num));
|
||||
double compute_error = 0;
|
||||
if constexpr(std::is_same_v<ComputeDataType, I8> || std::is_same_v<ComputeDataType, I32> ||
|
||||
std::is_same_v<ComputeDataType, int>)
|
||||
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -106,14 +94,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
|
||||
}
|
||||
|
||||
static_assert(std::is_same_v<OutDataType, F8> || std::is_same_v<OutDataType, F16> ||
|
||||
std::is_same_v<OutDataType, BF16> || std::is_same_v<OutDataType, F32> ||
|
||||
std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>,
|
||||
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
|
||||
|
||||
double output_error = 0;
|
||||
if constexpr(std::is_same_v<OutDataType, I8> || std::is_same_v<OutDataType, I32> ||
|
||||
std::is_same_v<OutDataType, int>)
|
||||
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
@@ -123,14 +108,11 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
|
||||
}
|
||||
double midway_error = std::max(compute_error, output_error);
|
||||
|
||||
static_assert(std::is_same_v<AccDataType, F8> || std::is_same_v<AccDataType, F16> ||
|
||||
std::is_same_v<AccDataType, BF16> || std::is_same_v<AccDataType, F32> ||
|
||||
std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>,
|
||||
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
|
||||
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
|
||||
|
||||
double acc_error = 0;
|
||||
if constexpr(std::is_same_v<AccDataType, I8> || std::is_same_v<AccDataType, I32> ||
|
||||
std::is_same_v<AccDataType, int>)
|
||||
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -14,57 +14,41 @@ namespace detail {
|
||||
template <typename OldLayout>
|
||||
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
|
||||
{
|
||||
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
|
||||
using namespace ck_tile::tensor_layout::convolution;
|
||||
|
||||
if constexpr(is_any_of<OldLayout, GNCW, GKCX, GNKW>::value)
|
||||
{
|
||||
return {0, 1, 2, 3};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
|
||||
else if constexpr(is_any_of<OldLayout, GNCHW, GKCYX, GNKHW>::value)
|
||||
{
|
||||
return {0, 1, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
|
||||
else if constexpr(is_any_of<OldLayout, GNCDHW, GKCZYX, GNKDHW>::value)
|
||||
{
|
||||
return {0, 1, 2, 3, 4, 5};
|
||||
}
|
||||
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
|
||||
if constexpr(is_any_of<OldLayout, GNWC, GKXC, GNWK>::value)
|
||||
{
|
||||
return {0, 1, 3, 2};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
|
||||
else if constexpr(is_any_of<OldLayout, GNHWC, GKYXC, GNHWK>::value)
|
||||
{
|
||||
return {0, 1, 4, 2, 3};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
|
||||
else if constexpr(is_any_of<OldLayout, GNDHWC, GKZYXC, GNDHWK>::value)
|
||||
{
|
||||
return {0, 1, 5, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
|
||||
else if constexpr(is_any_of<OldLayout, NWGC, KXGC, NWGK>::value)
|
||||
{
|
||||
return {2, 0, 3, 1};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
|
||||
else if constexpr(is_any_of<OldLayout, NHWGC, KYXGC, NHWGK>::value)
|
||||
{
|
||||
return {3, 0, 4, 1, 2};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
|
||||
else if constexpr(is_any_of<OldLayout, NDHWGC, KZYXGC, NDHWGK>::value)
|
||||
{
|
||||
return {4, 0, 5, 1, 2, 3};
|
||||
}
|
||||
@@ -83,11 +67,11 @@ template <typename InLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
using namespace ck_tile::tensor_layout::convolution;
|
||||
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
|
||||
if constexpr(is_any_of<InLayout, GNCW, GNCHW, GNCDHW>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
|
||||
else if constexpr(is_any_of<InLayout, GNWC, GNHWC, GNDHWC>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
|
||||
else if constexpr(is_any_of<InLayout, NWGC, NHWGC, NDHWGC>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
@@ -139,11 +119,11 @@ template <typename WeiLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
using namespace ck_tile::tensor_layout::convolution;
|
||||
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
|
||||
if constexpr(is_any_of<WeiLayout, KXC, KYXC, KZYXC>::value)
|
||||
{
|
||||
if(param.G_ != 1)
|
||||
{
|
||||
@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
|
||||
else if constexpr(is_any_of<WeiLayout, GKCX, GKCYX, GKCZYX>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
|
||||
else if constexpr(is_any_of<WeiLayout, GKXC, GKYXC, GKZYXC>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
|
||||
else if constexpr(is_any_of<WeiLayout, KXGC, KYXGC, KZYXGC>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
@@ -211,11 +185,11 @@ template <typename OutLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
using namespace ck_tile::tensor_layout::convolution;
|
||||
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
|
||||
if constexpr(is_any_of<OutLayout, GNKW, GNKHW, GNKDHW>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
|
||||
else if constexpr(is_any_of<OutLayout, GNWK, GNHWK, GNDHWK>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
|
||||
else if constexpr(is_any_of<OutLayout, NWGK, NHWGK, NDHWGK>::value)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
|
||||
Reference in New Issue
Block a user