mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Revert "Add padding to 1x1Stride1Pad0 conv specialization (grouped conv bwd weight) (#2610)" (#2637)
This reverts commit 2203b0ddfe.
Co-authored-by: Bartłomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -222,6 +222,9 @@
|
||||
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
|
||||
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
|
||||
|
||||
// workaround: conv crash when K, C is even
|
||||
#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1
|
||||
|
||||
// workaround: compiler crash when compiling recursive lambda
|
||||
#define CK_WORKAROUND_SWDEV_275126 1
|
||||
|
||||
|
||||
@@ -331,8 +331,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -1299,6 +1299,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// workaround: disable when K, C is even
|
||||
#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN
|
||||
if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
// check if it's 1x1, stride=1 pad = 0 conv
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
@@ -1323,7 +1330,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(gemm_arg);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
@@ -607,203 +606,6 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
|
||||
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
|
||||
"Invalid tuning param!");
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
|
||||
{
|
||||
if(!(karg.M % MPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
|
||||
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
|
||||
{
|
||||
if(!(karg.N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
|
||||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
|
||||
{
|
||||
|
||||
auto K_t = karg.KBatch * KPerBlock;
|
||||
if(!(karg.K % K_t == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
|
||||
<< karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = karg.KBatch * KReadVec;
|
||||
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
|
||||
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
if(karg.K % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % ABlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
|
||||
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
if(karg.N % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.K % BBlockTransferSrcScalarPerVector != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K (" << karg.K
|
||||
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
|
||||
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
|
||||
{
|
||||
if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg N (" << karg.N
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg M (" << karg.M
|
||||
<< ") value is not a multiple of "
|
||||
"CShuffleBlockTransferScalarPerVector_NPerBlock ("
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
|
||||
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, float>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
|
||||
is_same<remove_cvref_t<CDataType>, int32_t>::value))
|
||||
{
|
||||
if(!karg.IsReduceAdd())
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
if(karg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check gridwise gemm pipeline
|
||||
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
@@ -192,7 +192,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -210,7 +210,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -218,17 +218,9 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
|
||||
|
||||
// Padd
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -248,7 +240,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -287,7 +279,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -296,6 +288,26 @@ struct TransformConvBwdWeightToGemm
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch),
|
||||
make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch),
|
||||
make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
@@ -303,8 +315,8 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
}
|
||||
@@ -380,7 +392,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -395,21 +407,13 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// Padd
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -424,7 +428,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -465,11 +469,31 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch),
|
||||
make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch),
|
||||
make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
@@ -477,8 +501,8 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
}
|
||||
@@ -561,7 +585,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -576,21 +600,13 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// Padd
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -605,7 +621,7 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -655,11 +671,31 @@ struct TransformConvBwdWeightToGemm
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch),
|
||||
make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch),
|
||||
make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
@@ -667,8 +703,8 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
} // function end
|
||||
|
||||
@@ -374,7 +374,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
@@ -390,21 +390,13 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -420,7 +412,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
@@ -461,11 +453,29 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
@@ -473,8 +483,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
|
||||
@@ -552,7 +562,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
@@ -568,21 +578,13 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -598,7 +600,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
@@ -648,11 +650,29 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
@@ -660,8 +680,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
}
|
||||
@@ -745,7 +765,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
@@ -761,21 +781,13 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -791,7 +803,7 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmM, PadGemmM)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
@@ -856,11 +868,29 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Padd
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
@@ -868,8 +898,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
} // function end
|
||||
|
||||
Reference in New Issue
Block a user