From 99fe3df99ac8d544d4eebecb09b0c17ea40fa084 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Fri, 3 Oct 2025 14:23:04 +0000 Subject: [PATCH] Fix tensor descriptors. --- .../utils/transform_conv_bwd_weight_to_gemm.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp index 872316112f..03b6f153b1 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -445,7 +445,7 @@ struct TransformConvBwdWeightToGemm if constexpr (NumGroupsToMerge > 1) { - constexpr auto BatchStride = C_; + const auto BatchStride = C_; return make_naive_tensor_descriptor( make_tuple(N_, Wi_, NumGroupsToMerge, C_), make_tuple(NStride, WiStride, BatchStride, CStride)); @@ -551,7 +551,7 @@ struct TransformConvBwdWeightToGemm if constexpr (NumGroupsToMerge > 1) { - constexpr auto BatchStride = C_; + const auto BatchStride = C_; return make_naive_tensor_descriptor( make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); @@ -632,7 +632,7 @@ struct TransformConvBwdWeightToGemm if constexpr (NumGroupsToMerge > 1) { - constexpr auto BatchStride = G_; + const auto BatchStride = G_; return make_naive_tensor_descriptor( make_tuple(K_, NumGroupsToMerge, N_ * Do_ * Ho_ * Wo_), make_tuple(KStride, BatchStride, NDoHoWoStride)); @@ -683,8 +683,8 @@ struct TransformConvBwdWeightToGemm // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder // for Batch+N dimension const auto desc = make_naive_tensor_descriptor( - make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_ * C_), - make_tuple(BatchStride, KStride, ZYXStride, CStride)); + make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_), + make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride)); // Padd 1 to NumGroupsToMerge const auto padded_desc = transform_tensor_descriptor( desc, @@ -949,7 +949,7 @@ struct TransformConvBwdWeightToGemm make_tuple( make_pass_through_transform(N_), make_pad_transform(Di_, InLeftPadD_, InRightPadD_), - make_pad_transform(Wi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), make_pass_through_transform(NumGroupsToMerge), make_pass_through_transform(C_)),