Fix tensor descriptors.

This commit is contained in:
Ville Pietilä
2025-10-03 14:23:04 +00:00
parent 9510171377
commit 99fe3df99a

View File

@@ -445,7 +445,7 @@ struct TransformConvBwdWeightToGemm
if constexpr (NumGroupsToMerge > 1)
{
constexpr auto BatchStride = C_;
const auto BatchStride = C_;
return make_naive_tensor_descriptor(
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, WiStride, BatchStride, CStride));
@@ -551,7 +551,7 @@ struct TransformConvBwdWeightToGemm
if constexpr (NumGroupsToMerge > 1)
{
constexpr auto BatchStride = C_;
const auto BatchStride = C_;
return make_naive_tensor_descriptor(
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
@@ -632,7 +632,7 @@ struct TransformConvBwdWeightToGemm
if constexpr (NumGroupsToMerge > 1)
{
constexpr auto BatchStride = G_;
const auto BatchStride = G_;
return make_naive_tensor_descriptor(
make_tuple(K_, NumGroupsToMerge, N_ * Do_ * Ho_ * Wo_),
make_tuple(KStride, BatchStride, NDoHoWoStride));
@@ -683,8 +683,8 @@ struct TransformConvBwdWeightToGemm
// Add NumGroupsToMerge for the Batch+M dimension, and 1 as a placeholder
// for the Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_ * C_),
make_tuple(BatchStride, KStride, ZYXStride, CStride));
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_),
make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride));
// Pad 1 to NumGroupsToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
@@ -949,7 +949,7 @@ struct TransformConvBwdWeightToGemm
make_tuple(
make_pass_through_transform(N_),
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
make_pad_transform(Wi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
make_pass_through_transform(NumGroupsToMerge),
make_pass_through_transform(C_)),