[CK_TILE] Switch into universal gemms for conv bwds (#2981)

* switch into universal gemms for conv bwds

* some fixes and support universal gemm in conv fwd

* add reviewer comments
This commit is contained in:
jakpiase
2025-10-14 16:09:16 +02:00
committed by GitHub
parent 589e242eda
commit 6deaaa92cc
19 changed files with 1043 additions and 550 deletions

View File

@@ -82,20 +82,14 @@ struct GroupedConvTraits
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
// TODO: Change to and enable vector load
// ck_tile::tensor_layout::gemm::RowMajor,
// ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
using GroupedConvImplicitGemmTraitsBwdWeight =
TileGemmTraits<true,
true,
true,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor,
// TODO: Change to and enable vector load
// ck_tile::tensor_layout::gemm::ColumnMajor,
// ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::RowMajor>;
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;

View File

@@ -502,7 +502,7 @@ struct TransformConvBwdDataToGemm
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
number<VectorSizeC>{},
I1);
}
@@ -512,7 +512,7 @@ struct TransformConvBwdDataToGemm
// GKYXC
return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_),
make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1),
number<VectorSizeC>{},
number<VectorSizeB>{},
I1);
}
@@ -547,7 +547,7 @@ struct TransformConvBwdDataToGemm
return make_naive_tensor_descriptor(
make_tuple(N_, Di_, Hi_, Wi_, C_),
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
number<VectorSizeC>{},
I1);
}
@@ -558,7 +558,7 @@ struct TransformConvBwdDataToGemm
return make_naive_tensor_descriptor(
make_tuple(K_, Z_, Y_, X_, C_),
make_tuple(C_ * X_ * Y_ * Z_, C_ * X_ * Y_, C_ * X_, C_, I1),
number<VectorSizeC>{},
number<VectorSizeB>{},
I1);
}
// TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimension as
@@ -642,7 +642,7 @@ struct TransformConvBwdDataToGemm
make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
make_tuple(sequence<0>{}, sequence<1>{}));
// c: input
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
@@ -797,7 +797,7 @@ struct TransformConvBwdDataToGemm
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(sequence<1, 2, 0>{}, sequence<3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
make_tuple(sequence<0>{}, sequence<1>{}));
// c: input
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
@@ -999,7 +999,7 @@ struct TransformConvBwdDataToGemm
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
make_pass_through_transform(C_)),
make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
make_tuple(sequence<0>{}, sequence<1>{}));
// c: input
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(

View File

@@ -421,7 +421,6 @@ struct TransformConvBwdWeightToGemm
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
number<VectorSizeA>{},
@@ -463,9 +462,8 @@ struct TransformConvBwdWeightToGemm
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
@@ -480,7 +478,7 @@ struct TransformConvBwdWeightToGemm
constexpr auto CStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
make_tuple(NStride, HiStride, WiStride, CStride),
number<VectorSizeB>{},
I1);
@@ -506,9 +504,8 @@ struct TransformConvBwdWeightToGemm
constexpr auto KStride = I1;
// TODO Add support for NumGroupsToMerge > 1
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStride, NDoHoWoStride),
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
make_tuple(NDoHoWoStride, KStride),
number<VectorSizeA>{},
I1);
}
@@ -577,7 +574,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(make_merge_transform(make_tuple(X_, C_)),
make_merge_transform(make_tuple(N_, Wo_))),
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
@@ -614,7 +611,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}
@@ -657,7 +654,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
make_tuple(sequence<1>{}, sequence<0>{}));
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
}