mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Fix tensor descriptors.
This commit is contained in:
@@ -445,7 +445,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
constexpr auto BatchStride = C_;
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, WiStride, BatchStride, CStride));
|
||||
@@ -551,7 +551,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
constexpr auto BatchStride = C_;
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
|
||||
@@ -632,7 +632,7 @@ struct TransformConvBwdWeightToGemm
|
||||
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
constexpr auto BatchStride = G_;
|
||||
const auto BatchStride = G_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, NumGroupsToMerge, N_ * Do_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, BatchStride, NDoHoWoStride));
|
||||
@@ -683,8 +683,8 @@ struct TransformConvBwdWeightToGemm
|
||||
// Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder
|
||||
// for Batch+N dimension
|
||||
const auto desc = make_naive_tensor_descriptor(
|
||||
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_ * C_),
|
||||
make_tuple(BatchStride, KStride, ZYXStride, CStride));
|
||||
make_tuple(NumGroupsToMerge, K_, Z_ * Y_ * X_, 1, C_),
|
||||
make_tuple(BatchStride, KStride, ZYXStride, BatchStride, CStride));
|
||||
// Padd 1 to NumGroupsToMerge
|
||||
const auto padded_desc = transform_tensor_descriptor(
|
||||
desc,
|
||||
@@ -949,7 +949,7 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(
|
||||
make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Wi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
|
||||
Reference in New Issue
Block a user