Change relu to clamp for grouped conv fwd instances (#2249)

This commit is contained in:
Bartłomiej Kocot
2025-05-29 00:51:25 +02:00
committed by GitHub
parent 6df1c56ad6
commit e7906dd644
34 changed files with 291 additions and 195 deletions

View File

@@ -115,6 +115,7 @@ using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;

View File

@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -13,7 +13,7 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_USE_XDL
#include "grouped_convolution_forward_bias_relu_xdl.inc"
#include "grouped_convolution_forward_bias_clamp_xdl.inc"
#endif
namespace ck {
@@ -44,7 +44,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddRelu,
ck::tensor_operation::element_wise::AddClamp,
AComputeType,
BComputeType>>
{
@@ -60,7 +60,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddRelu,
ck::tensor_operation::element_wise::AddClamp,
AComputeType,
BComputeType>;
@@ -80,23 +80,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
@@ -112,19 +112,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif

View File

@@ -10,7 +10,7 @@ namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -22,9 +22,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -36,9 +36,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -50,9 +50,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -64,9 +64,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -78,9 +78,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -92,9 +92,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -106,9 +106,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -120,9 +120,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -134,9 +134,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -148,9 +148,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -162,9 +162,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -176,9 +176,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -190,9 +190,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -204,9 +204,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -218,9 +218,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -232,7 +232,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
#endif