re-enable clang-format by default (#3030)

* re-enable clang-format by default

* fix clang format
This commit is contained in:
Illia Silin
2025-10-15 07:43:11 -07:00
committed by GitHub
parent bde5f26db3
commit 3348f01e6f
7 changed files with 126 additions and 124 deletions

View File

@@ -44,13 +44,13 @@ struct GroupedConvBwdDataKernelArgs
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -145,15 +145,15 @@ struct GroupedConvBwdDataKernelArgs
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -161,13 +161,13 @@ struct GroupedConvBwdDataKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[1])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1])};
static_cast<index_t>(args.conv_filter_strides_[1])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1])};
static_cast<index_t>(args.input_left_pads_[1])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
@@ -262,17 +262,17 @@ struct GroupedConvBwdDataKernelArgs
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -281,17 +281,17 @@ struct GroupedConvBwdDataKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[2])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1]),
static_cast<index_t>(args.conv_filter_dilations_[2])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
@@ -387,8 +387,8 @@ struct GroupedConvBwdDataKernelArgs
static constexpr index_t MaxGroupedGemmGroupsNum = 128;
using ABCGridDescs = remove_cvref_t<decltype(
ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;
using ABCGridDescs = remove_cvref_t<
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;
using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;

View File

@@ -40,13 +40,13 @@ struct GroupedConvBwdWeightKernelArgs
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -109,15 +109,15 @@ struct GroupedConvBwdWeightKernelArgs
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -125,13 +125,13 @@ struct GroupedConvBwdWeightKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[1])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1])};
static_cast<index_t>(args.conv_filter_strides_[1])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1])};
static_cast<index_t>(args.input_left_pads_[1])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
@@ -185,17 +185,17 @@ struct GroupedConvBwdWeightKernelArgs
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -204,17 +204,17 @@ struct GroupedConvBwdWeightKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[2])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1]),
static_cast<index_t>(args.conv_filter_dilations_[2])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
@@ -257,8 +257,8 @@ struct GroupedConvBwdWeightKernelArgs
GemmBatch = args.G_;
}
using ABCGridDescs = remove_cvref_t<decltype(
ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
using ABCGridDescs = remove_cvref_t<
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
using AGridDescKM = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
using BGridDescKN = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;

View File

@@ -41,13 +41,13 @@ struct GroupedConvFwdKernelArgs
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -124,15 +124,15 @@ struct GroupedConvFwdKernelArgs
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -140,13 +140,13 @@ struct GroupedConvFwdKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[1])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1])};
static_cast<index_t>(args.conv_filter_strides_[1])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1])};
static_cast<index_t>(args.input_left_pads_[1])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1])};
static_cast<index_t>(args.input_right_pads_[1])};
k_batch = args.k_batch;
@@ -216,17 +216,17 @@ struct GroupedConvFwdKernelArgs
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
static_cast<index_t>(args.N_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.input_spatial_lengths_[0]),
static_cast<index_t>(args.input_spatial_lengths_[1]),
static_cast<index_t>(args.input_spatial_lengths_[2])};
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
static_cast<index_t>(args.K_),
static_cast<index_t>(args.C_),
static_cast<index_t>(args.filter_spatial_lengths_[0]),
static_cast<index_t>(args.filter_spatial_lengths_[1]),
static_cast<index_t>(args.filter_spatial_lengths_[2])};
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
static_cast<index_t>(args.K_),
@@ -235,17 +235,17 @@ struct GroupedConvFwdKernelArgs
static_cast<index_t>(args.output_spatial_lengths_[2])};
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
static_cast<index_t>(args.conv_filter_strides_[1]),
static_cast<index_t>(args.conv_filter_strides_[2])};
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
static_cast<index_t>(args.conv_filter_dilations_[1]),
static_cast<index_t>(args.conv_filter_dilations_[2])};
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
static_cast<index_t>(args.input_left_pads_[1]),
static_cast<index_t>(args.input_left_pads_[2])};
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
static_cast<index_t>(args.input_right_pads_[1]),
static_cast<index_t>(args.input_right_pads_[2])};
k_batch = args.k_batch;
@@ -306,15 +306,15 @@ struct GroupedConvFwdKernelArgs
args.output_spatial_lengths_[2];
}
using AGridDescMK = remove_cvref_t<decltype(
ConvToGemmFwdTransformer{}
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
using BGridDescNK = remove_cvref_t<decltype(
ConvToGemmFwdTransformer{}
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
using CGridDescMN = remove_cvref_t<decltype(
ConvToGemmFwdTransformer{}
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
using AGridDescMK = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
using BGridDescNK = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
using CGridDescMN = remove_cvref_t<
decltype(ConvToGemmFwdTransformer{}
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
static constexpr index_t NonSpatialDims = 3;
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;