mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK TILE] Add index optimizations for conv bwd weight (#3321)
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
ConvolutionSpecialization ConvolutionSpecialization,
|
||||
ConvolutionSpecialization ConvSpec,
|
||||
index_t VectorSizeA,
|
||||
index_t VectorSizeB,
|
||||
index_t VectorSizeC,
|
||||
@@ -440,21 +440,43 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Wi_, C_),
|
||||
make_tuple(WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_),
|
||||
make_tuple(NStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -551,21 +573,44 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const auto BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N
|
||||
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Hi_ * Wi_, NumGroupsToMerge, C_), // K_Gm_N
|
||||
make_tuple(WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Hi_ * Wi_, C_), // K_N
|
||||
make_tuple(WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), // K_Gm_N
|
||||
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Hi_, Wi_, C_), // K_N
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -664,22 +709,44 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t WiStride = G_ * C_;
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_ * Di_ * Hi_ * Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Di_ * Hi_ * Wi_, C_),
|
||||
make_tuple(WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const index_t BatchStride = C_;
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -755,83 +822,111 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto out_grid_merged_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
|
||||
const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
const auto in_grid_merged_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
|
||||
const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4] -> [0, 1]
|
||||
// [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)]
|
||||
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
|
||||
in_n_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
|
||||
return make_tuple(out_grid_merged_desc, in_grid_merged_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(out_grid_desc, in_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
|
||||
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Wo), Gm, K] -> [(N*Wo), (Gm*K)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// [N, Wip, C] -> [N, X, Wo, C]
|
||||
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Wi, Gm, C] -> [N, (Wi + InLeftPadW + InRightPadW), Gm, C] = [N, Wip, Gm, C]
|
||||
const auto in_n_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc =
|
||||
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Wip, Gm, C] -> [N, X, Wo, Gm, C]
|
||||
const auto in_n_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4] -> [0, 1]
|
||||
// [N, X, Wo, Gm, C] -> [(N*Wo), (Gm*X*C)]
|
||||
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
|
||||
in_n_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 4>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
// [N, Wi, C] -> [N, (Wi + InLeftPadW + InRightPadW), C] = [N, Wip, C]
|
||||
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
// [N, Wip, C] -> [N, X, Wo, C]
|
||||
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -843,94 +938,122 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (K*Gm)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto out_grid_merged_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C]
|
||||
const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
const auto in_grid_merged_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Hip, Wip, Gm, C] -> [N, (Y, Wo), (X, Wo), Gm, C]
|
||||
const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5>{},
|
||||
sequence<6>{}));
|
||||
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4 5 6] -> [0, 1]
|
||||
// [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Gm*Y*X*C)]
|
||||
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
|
||||
return make_tuple(out_grid_merged_desc, in_grid_merged_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(out_grid_desc, in_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Ho*Wo), Gm, K] -> [(N*Ho*Wo), (K*Gm)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Hi, Wi, Gm, C] -> [N, Hip, Wip, Gm, C]
|
||||
const auto in_n_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Hip, Wip, Gm, C] -> [N, (Y, Wo), (X, Wo), Gm, C]
|
||||
const auto in_n_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5>{},
|
||||
sequence<6>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4 5 6] -> [0, 1]
|
||||
// [N, Y, Ho, X, Wo, Gm, C] -> [(N*Ho*Wo), (Gm*Y*X*C)]
|
||||
const auto in_gemm_n_gemm_k_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 6>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_n_gemm_k_grid_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -942,120 +1065,148 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDimSpatial>();
|
||||
|
||||
// B: input tensor comes in K_N
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (K*Gm)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
const auto out_grid_merged_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C]
|
||||
const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
const auto in_grid_merged_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, C_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Zip, Hip, Wip, Gm, C] -> [N, (Z, Zo), (Y, Wo), (X, Wo), Gm, C]
|
||||
const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_dip_hip_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_),
|
||||
make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{},
|
||||
sequence<8>{}));
|
||||
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1]
|
||||
// [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)]
|
||||
const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
|
||||
in_n_z_do_y_ho_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
|
||||
return make_tuple(out_grid_merged_desc, in_grid_merged_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tuple(out_grid_desc, in_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
if constexpr(NumGroupsToMerge > 1)
|
||||
{
|
||||
// Output tensor transformation
|
||||
// [0, 1, 2] -> [0, 1]
|
||||
// [(N*Do*Ho*Wo), Gm, K] -> [(N*Do*Ho*Wo), (K*Gm)]
|
||||
const auto out_gemm_k_gemm_m_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_merge_transform(make_tuple(NumGroupsToMerge, K_))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_),
|
||||
make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{}));
|
||||
// Input tensor transformation, part 1.
|
||||
// [N, Di, Hi, Wi, Gm, C] -> [N, Dip, Hip, Wip, Gm, C]
|
||||
const auto in_n_dip_hip_wip_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// Input tensor transformation, part 2.
|
||||
// [N, Zip, Hip, Wip, Gm, C] -> [N, (Z, Zo), (Y, Wo), (X, Wo), Gm, C]
|
||||
const auto in_n_z_do_y_ho_x_wo_gm_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_dip_hip_wip_gm_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_),
|
||||
make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(NumGroupsToMerge),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{},
|
||||
sequence<8>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
// Input tensor transformation, part 3.
|
||||
// [0, 1, 2, 3, 4, 5, 6, 7, 8] -> [0, 1]
|
||||
// [N, Z, Do, Y, Ho, X, Wo, Gm, C] -> [(N*Do*Ho*Wo), (Z*Y*X*Gm*C)]
|
||||
const auto in_gemm_k_gemm_n_grid_desc = transform_tensor_descriptor(
|
||||
in_n_z_do_y_ho_x_wo_gm_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, NumGroupsToMerge, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7, 8>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
out_gemm_k_gemm_m_grid_desc, in_gemm_k_gemm_n_grid_desc, wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
|
||||
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
|
||||
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_),
|
||||
make_embed_transform(make_tuple(Z_, Do_),
|
||||
make_tuple(ConvDilationD_, ConvStrideD_)),
|
||||
make_embed_transform(make_tuple(Y_, Ho_),
|
||||
make_tuple(ConvDilationH_, ConvStrideH_)),
|
||||
make_embed_transform(make_tuple(X_, Wo_),
|
||||
make_tuple(ConvDilationW_, ConvStrideW_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(
|
||||
sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1, 2>{},
|
||||
sequence<3, 4>{},
|
||||
sequence<5, 6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user