|
|
|
|
@@ -41,8 +41,8 @@ namespace tensor_operation {
|
|
|
|
|
namespace device {
|
|
|
|
|
|
|
|
|
|
template <typename GridwiseGemm,
|
|
|
|
|
typename AGridDesc_AK0_M_K1,
|
|
|
|
|
typename BGridDesc_BK0_N_K1,
|
|
|
|
|
typename AGridDesc_M_K,
|
|
|
|
|
typename BGridDesc_N_K,
|
|
|
|
|
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
|
|
|
|
typename ComputePtrOffsetOfBatch,
|
|
|
|
|
bool HasMainKBlockLoop,
|
|
|
|
|
@@ -55,8 +55,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
|
|
|
|
#endif
|
|
|
|
|
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3(
|
|
|
|
|
typename GridwiseGemm::Argument karg,
|
|
|
|
|
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
|
|
|
|
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
|
|
|
|
const AGridDesc_M_K a_grid_desc_m_k,
|
|
|
|
|
const BGridDesc_N_K b_grid_desc_n_k,
|
|
|
|
|
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
|
|
|
|
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
|
|
|
|
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
|
|
|
|
@@ -67,14 +67,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
|
|
|
|
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
|
|
|
|
|
{
|
|
|
|
|
#endif
|
|
|
|
|
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
|
|
|
|
typename GridwiseGemm::EpilogueCShuffle>();
|
|
|
|
|
using EpilogueType =
|
|
|
|
|
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
|
|
|
|
|
GridwiseGemm::UseDirectStore,
|
|
|
|
|
typename GridwiseGemm::EpilogueDirectStore,
|
|
|
|
|
typename GridwiseGemm::EpilogueCShuffle>::type;
|
|
|
|
|
|
|
|
|
|
constexpr index_t LDS_size =
|
|
|
|
|
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
|
|
|
|
|
__shared__ char p_shared[LDS_size];
|
|
|
|
|
|
|
|
|
|
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
|
|
|
|
auto epilogue_args = EpilogueType{};
|
|
|
|
|
|
|
|
|
|
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
|
|
|
|
|
BGridDesc_BK0_N_K1,
|
|
|
|
|
const auto a_grid_desc_ak0_m_ak1 =
|
|
|
|
|
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
|
|
|
|
|
|
|
|
|
|
const auto b_grid_desc_bk0_n_bk1 =
|
|
|
|
|
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
|
|
|
|
|
|
|
|
|
|
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
|
|
|
|
|
decltype(b_grid_desc_bk0_n_bk1),
|
|
|
|
|
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
|
|
|
|
ComputePtrOffsetOfBatch,
|
|
|
|
|
1,
|
|
|
|
|
@@ -144,6 +156,7 @@ template <ck::index_t NDimSpatial,
|
|
|
|
|
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
|
|
|
|
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
|
|
|
|
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
|
|
|
|
bool UseThreadTileTransfer = true,
|
|
|
|
|
typename ComputeTypeA = InDataType,
|
|
|
|
|
typename ComputeTypeB = ComputeTypeA,
|
|
|
|
|
index_t MaxTransposeTransferSrcScalarPerVector = 1,
|
|
|
|
|
@@ -171,7 +184,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
using ADataType = OutDataType;
|
|
|
|
|
using BDataType = InDataType;
|
|
|
|
|
using CDataType = WeiDataType;
|
|
|
|
|
// // static const auto F1S1 =
|
|
|
|
|
// ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
|
|
|
|
// #if defined USE_WAVE
|
|
|
|
|
|
|
|
|
|
// static_assert(UseThreadTileTransfer==false &&
|
|
|
|
|
// (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0
|
|
|
|
|
// ),"Only Filter1x1Stride1Pad0 is supported for wavetile transfer"
|
|
|
|
|
// );
|
|
|
|
|
// #endif
|
|
|
|
|
// If NGCHW then ADataType must be equal to BDataType
|
|
|
|
|
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
|
|
|
|
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
|
|
|
|
|
@@ -293,6 +314,33 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
batch);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Desc_K0_M_K1>
|
|
|
|
|
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
|
|
|
|
|
{
|
|
|
|
|
const auto grid_desc_m_k = transform_tensor_descriptor(
|
|
|
|
|
desc_k0_m_k1,
|
|
|
|
|
make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)),
|
|
|
|
|
make_merge_transform(
|
|
|
|
|
make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))),
|
|
|
|
|
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
|
|
|
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
|
|
|
|
|
|
|
|
|
return grid_desc_m_k;
|
|
|
|
|
}
|
|
|
|
|
template <typename Desc_K0_N_K1>
|
|
|
|
|
static auto transform_k0_m_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1)
|
|
|
|
|
{
|
|
|
|
|
const auto grid_desc_n_k = transform_tensor_descriptor(
|
|
|
|
|
desc_k0_n_k1,
|
|
|
|
|
make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)),
|
|
|
|
|
make_merge_transform(
|
|
|
|
|
make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))),
|
|
|
|
|
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
|
|
|
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
|
|
|
|
|
|
|
|
|
return grid_desc_n_k;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using NGCHWTransposeDescType =
|
|
|
|
|
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
|
|
|
|
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
|
|
|
|
@@ -308,9 +356,12 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
|
|
|
|
|
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
|
|
|
|
|
|
|
|
|
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
|
|
|
|
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
|
|
|
|
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
|
|
|
|
using AGridDesc_M_K_ = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
|
|
|
|
using BGridDesc_N_K_ = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
|
|
|
|
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
|
|
|
|
|
|
|
|
|
using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_M_K_{}));
|
|
|
|
|
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_n_k(BGridDesc_N_K_{}));
|
|
|
|
|
|
|
|
|
|
using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
|
|
|
|
|
|
|
|
|
@@ -401,10 +452,10 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
BlkGemmPipelineVer,
|
|
|
|
|
ComputeTypeA,
|
|
|
|
|
ComputeTypeB,
|
|
|
|
|
false, // PermuteA
|
|
|
|
|
false, // permuteB
|
|
|
|
|
false, // IsBPreshuffle
|
|
|
|
|
true>; // ForceThreadTileTransfer
|
|
|
|
|
false, // PermuteA
|
|
|
|
|
false, // permuteB
|
|
|
|
|
false, // IsBPreshuffle
|
|
|
|
|
UseThreadTileTransfer>; // ForceThreadTileTransfer
|
|
|
|
|
|
|
|
|
|
// Argument
|
|
|
|
|
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
|
|
|
|
@@ -434,8 +485,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
&max_occupancy,
|
|
|
|
|
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
|
|
|
|
GridwiseGemm,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_M_K>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_N_K>,
|
|
|
|
|
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
|
|
|
|
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
|
|
|
|
true,
|
|
|
|
|
@@ -473,8 +524,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
: p_a_grid_{p_out_grid},
|
|
|
|
|
p_b_grid_{p_in_grid},
|
|
|
|
|
p_c_grid_{p_wei_grid},
|
|
|
|
|
a_grid_desc_kbatch_k0_m_k1_{},
|
|
|
|
|
b_grid_desc_kbatch_k0_n_k1_{},
|
|
|
|
|
a_grid_desc_kbatch_m_k_{},
|
|
|
|
|
b_grid_desc_kbatch_n_k_{},
|
|
|
|
|
c_grid_desc_m_n_{},
|
|
|
|
|
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
|
|
|
|
compute_ptr_offset_of_batch_{},
|
|
|
|
|
@@ -572,16 +623,16 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
input_right_pads,
|
|
|
|
|
k_batch_);
|
|
|
|
|
|
|
|
|
|
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
|
|
|
|
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
|
|
|
|
c_grid_desc_m_n_ = descs[I2];
|
|
|
|
|
a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]);
|
|
|
|
|
b_grid_desc_kbatch_n_k_ = transform_k0_m_k1_to_n_k(descs[I1]);
|
|
|
|
|
c_grid_desc_m_n_ = descs[I2];
|
|
|
|
|
|
|
|
|
|
// A/B/C Batch Stride
|
|
|
|
|
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
|
|
|
|
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
|
|
|
|
|
compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
|
|
|
|
|
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
|
|
|
|
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
|
|
|
|
const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0);
|
|
|
|
|
const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0);
|
|
|
|
|
|
|
|
|
|
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
|
|
|
|
GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
|
|
|
|
@@ -678,8 +729,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
const ADataType* p_a_grid_;
|
|
|
|
|
const BDataType* p_b_grid_;
|
|
|
|
|
CDataType* p_c_grid_;
|
|
|
|
|
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
|
|
|
|
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
|
|
|
|
AGridDesc_M_K a_grid_desc_kbatch_m_k_;
|
|
|
|
|
BGridDesc_N_K b_grid_desc_kbatch_n_k_;
|
|
|
|
|
CGridDesc_M_N c_grid_desc_m_n_;
|
|
|
|
|
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
|
|
|
|
|
|
|
|
|
@@ -724,17 +775,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
|
|
|
|
|
void ShowInfo(const Argument& arg)
|
|
|
|
|
{
|
|
|
|
|
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
|
|
|
|
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
|
|
|
|
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
|
|
|
|
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
|
|
|
|
|
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
|
|
|
|
|
std::cout << "arg.a_grid_desc_kbatch_m_k_{" << arg.a_grid_desc_kbatch_m_k_.GetLength(I0)
|
|
|
|
|
<< ", " << arg.a_grid_desc_kbatch_m_k_.GetLength(I1) << ", "
|
|
|
|
|
<< arg.a_grid_desc_kbatch_m_k_.GetLength(I2) << ", "
|
|
|
|
|
<< arg.a_grid_desc_kbatch_m_k_.GetLength(I3) << "}" << std::endl;
|
|
|
|
|
|
|
|
|
|
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
|
|
|
|
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
|
|
|
|
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
|
|
|
|
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
|
|
|
|
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
|
|
|
|
std::cout << "arg.b_grid_desc_kbatch_n_k_{" << arg.b_grid_desc_kbatch_n_k_.GetLength(I0)
|
|
|
|
|
<< ", " << arg.b_grid_desc_kbatch_n_k_.GetLength(I1) << ", "
|
|
|
|
|
<< arg.b_grid_desc_kbatch_n_k_.GetLength(I2) << ", "
|
|
|
|
|
<< arg.b_grid_desc_kbatch_n_k_.GetLength(I3) << "}" << std::endl;
|
|
|
|
|
|
|
|
|
|
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
|
|
|
|
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
|
|
|
|
@@ -744,10 +793,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
{
|
|
|
|
|
float ave_time = 0;
|
|
|
|
|
|
|
|
|
|
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
|
|
|
|
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
|
|
|
|
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
|
|
|
|
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
|
|
|
|
const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0);
|
|
|
|
|
const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0);
|
|
|
|
|
const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1);
|
|
|
|
|
|
|
|
|
|
const ADataType* p_a_grid = arg.p_a_grid_;
|
|
|
|
|
const BDataType* p_b_grid = arg.p_b_grid_;
|
|
|
|
|
@@ -839,10 +887,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock;
|
|
|
|
|
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
|
|
|
|
|
|
|
|
|
const auto num_k_per_block =
|
|
|
|
|
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
|
|
|
|
std::cout << "K0 value is:"
|
|
|
|
|
<< (GridwiseGemm::CalculateAK0Padded(
|
|
|
|
|
arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_))
|
|
|
|
|
<< std::endl;
|
|
|
|
|
|
|
|
|
|
const auto clear_workspace = [&]() {
|
|
|
|
|
const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded(
|
|
|
|
|
arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}), arg.k_batch_));
|
|
|
|
|
const auto clear_workspace = [&]() {
|
|
|
|
|
hip_check_error(
|
|
|
|
|
hipMemsetAsync(p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
|
|
|
|
};
|
|
|
|
|
@@ -855,11 +907,11 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
|
|
|
|
|
|
|
|
|
|
std::array<std::size_t, GridwiseGemm::NumATensor> size_as_buffers;
|
|
|
|
|
size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() *
|
|
|
|
|
size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() *
|
|
|
|
|
sizeof(ADataType) / GridwiseGemm::APackedSize;
|
|
|
|
|
|
|
|
|
|
std::array<std::size_t, GridwiseGemm::NumBTensor> size_bs_buffers;
|
|
|
|
|
size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() *
|
|
|
|
|
size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() *
|
|
|
|
|
sizeof(BDataType) / GridwiseGemm::BPackedSize;
|
|
|
|
|
|
|
|
|
|
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
|
|
|
|
|
@@ -889,8 +941,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
dim3(BlockSize),
|
|
|
|
|
0,
|
|
|
|
|
gemm_arg_,
|
|
|
|
|
arg.a_grid_desc_kbatch_k0_m_k1_,
|
|
|
|
|
arg.b_grid_desc_kbatch_k0_n_k1_,
|
|
|
|
|
arg.a_grid_desc_kbatch_m_k_,
|
|
|
|
|
arg.b_grid_desc_kbatch_n_k_,
|
|
|
|
|
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
|
|
|
|
arg.compute_ptr_offset_of_batch_,
|
|
|
|
|
num_k_per_block);
|
|
|
|
|
@@ -905,8 +957,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
dim3(BlockSize),
|
|
|
|
|
0,
|
|
|
|
|
gemm_arg,
|
|
|
|
|
arg.a_grid_desc_kbatch_k0_m_k1_,
|
|
|
|
|
arg.b_grid_desc_kbatch_k0_n_k1_,
|
|
|
|
|
arg.a_grid_desc_kbatch_m_k_,
|
|
|
|
|
arg.b_grid_desc_kbatch_n_k_,
|
|
|
|
|
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
|
|
|
|
arg.compute_ptr_offset_of_batch_,
|
|
|
|
|
num_k_per_block);
|
|
|
|
|
@@ -926,8 +978,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
{
|
|
|
|
|
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
|
|
|
|
GridwiseGemm,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_M_K>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_N_K>,
|
|
|
|
|
remove_reference_t<
|
|
|
|
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
|
|
|
|
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
|
|
|
|
@@ -940,8 +992,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
{
|
|
|
|
|
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
|
|
|
|
GridwiseGemm,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_M_K>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_N_K>,
|
|
|
|
|
remove_reference_t<
|
|
|
|
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
|
|
|
|
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
|
|
|
|
@@ -965,8 +1017,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
{
|
|
|
|
|
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
|
|
|
|
GridwiseGemm,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_M_K>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_N_K>,
|
|
|
|
|
remove_reference_t<
|
|
|
|
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
|
|
|
|
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
|
|
|
|
@@ -979,8 +1031,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
{
|
|
|
|
|
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
|
|
|
|
GridwiseGemm,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
|
|
|
|
remove_reference_t<DeviceOp::AGridDesc_M_K>,
|
|
|
|
|
remove_reference_t<DeviceOp::BGridDesc_N_K>,
|
|
|
|
|
remove_reference_t<
|
|
|
|
|
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
|
|
|
|
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
|
|
|
|
@@ -1042,10 +1094,15 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
|
|
|
|
|
|
|
|
|
|
static bool IsSupportedArgument(const Argument& arg)
|
|
|
|
|
{
|
|
|
|
|
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
|
|
|
|
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
|
|
|
|
|
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
|
|
|
|
|
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
|
|
|
|
|
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
|
|
|
|
|
if(arg.k_batch_ < 0)
|
|
|
|
|
{
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0);
|
|
|
|
|
const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0);
|
|
|
|
|
const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1);
|
|
|
|
|
|
|
|
|
|
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{nullptr}, // p_as_grid
|
|
|
|
|
std::array<const void*, 1>{nullptr}, // p_bs_grid
|
|
|
|
|
|