mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Grouped conv bwd data NGCHW (#1967)
* Grouped conv bwd data NGCHW * fixes * fix * Improvements * Fix * Fix * add client example
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
@@ -13,7 +14,9 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -202,9 +205,11 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType>
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename AComputeType = ADataType,
|
||||
typename BComputeType = AComputeType,
|
||||
index_t MaxTransposeTransferInScalarPerVector = 1,
|
||||
index_t MaxTransposeTransferOutScalarPerVector = 1>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
ALayout, // output image
|
||||
@@ -237,6 +242,19 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
using ALayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGK,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGK,
|
||||
ALayout>>;
|
||||
using ELayoutAfterTranspose =
|
||||
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NHWGC,
|
||||
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
|
||||
tensor_layout::convolution::NDHWGC,
|
||||
ELayout>>;
|
||||
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
@@ -246,9 +264,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
ALayoutAfterTranspose,
|
||||
BLayout,
|
||||
ELayout,
|
||||
ELayoutAfterTranspose,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
EDataType>;
|
||||
@@ -274,7 +292,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
ALayoutAfterTranspose,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
@@ -374,7 +392,70 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, MPerBlock>;
|
||||
|
||||
static constexpr index_t ClusterLengthMPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
|
||||
static constexpr index_t ClusterLengthNPerBlock =
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
|
||||
|
||||
static constexpr auto conv_ngchw_to_nhwgc_transformer =
|
||||
TransformConvNGCHWToNHWGC<ELayout,
|
||||
BLayout,
|
||||
ALayout,
|
||||
NDimSpatial,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock>{};
|
||||
|
||||
static constexpr index_t TransposeTransferInScalarPerVectorAligned =
|
||||
std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector);
|
||||
static constexpr index_t TransposeTransferOutScalarPerVectorAligned =
|
||||
std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector);
|
||||
|
||||
using NGCHWTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
using NHWGCTransposeDescType =
|
||||
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
|
||||
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
|
||||
|
||||
static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock;
|
||||
|
||||
using GridwiseElementwiseInputTranspose =
|
||||
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<const ADataType*>,
|
||||
Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
MPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<TransposeTransferInScalarPerVectorAligned>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
I1,
|
||||
I0>;
|
||||
|
||||
using GridwiseElementwiseOutputTranspose =
|
||||
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
|
||||
Tuple<NGCHWTransposeDescType>,
|
||||
Tuple<const EDataType*>,
|
||||
Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
ElementwiseBlocksize,
|
||||
NPerBlock,
|
||||
MPerBlock,
|
||||
NPerBlock / ClusterLengthNPerBlock,
|
||||
MPerBlock / ClusterLengthMPerBlock,
|
||||
Sequence<1, 0>,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
Sequence<TransposeTransferOutScalarPerVectorAligned>,
|
||||
I0,
|
||||
I1>;
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -409,10 +490,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeStrides(e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides);
|
||||
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
@@ -491,17 +580,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
|
||||
}
|
||||
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides_transposed,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
|
||||
conv_N_per_block_ = conv_to_gemm_transform_.N_;
|
||||
|
||||
@@ -527,7 +617,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
ALayoutAfterTranspose,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
@@ -535,7 +625,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DDataType>;
|
||||
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
a_g_n_k_wos_strides_transposed,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
@@ -591,12 +681,73 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
|
||||
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ =
|
||||
a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ =
|
||||
e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_;
|
||||
|
||||
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
// Use not modified base strides
|
||||
a_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
|
||||
a_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
|
||||
|
||||
e_in_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
|
||||
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
|
||||
e_out_transpose_desc_ =
|
||||
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
|
||||
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
|
||||
|
||||
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
|
||||
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
|
||||
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
|
||||
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
|
||||
|
||||
compute_ptr_offset_of_workspace_n_.BatchStrideA_ =
|
||||
a_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_workspace_n_.BatchStrideE_ =
|
||||
e_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
// Transpose require workspace for A and B
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -645,10 +796,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// block-to-e-tile map
|
||||
std::vector<Block2ETileMap> block_2_etile_map_container_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
|
||||
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
|
||||
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_workspace_n_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOp a_element_op_;
|
||||
@@ -657,9 +814,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
index_t num_workgroups_per_Conv_N_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -667,19 +827,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
float ave_time = 0;
|
||||
|
||||
const index_t gdy = arg.num_group_;
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
|
||||
const index_t gdz = num_workgroups_per_Conv_N;
|
||||
const index_t gdz = arg.num_workgroups_per_Conv_N_;
|
||||
|
||||
const ADataType* p_a_grid = arg.p_a_grid_;
|
||||
EDataType* p_e_grid = arg.p_e_grid_;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
|
||||
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
@@ -722,10 +887,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
p_a_grid,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
@@ -751,6 +916,114 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
float ave_time = 0;
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
// Transpose from NGKHW to NHWGK
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
EDataType* p_e_in_grid = type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
hip_check_error(hipMemsetAsync(p_e_in_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
|
||||
arg.a_in_transpose_desc_) *
|
||||
arg.num_workgroups_per_Conv_N_;
|
||||
|
||||
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
|
||||
|
||||
auto kernel_transpose =
|
||||
kernel_batched_elementwise<GridwiseElementwiseInputTranspose,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<const ADataType*>,
|
||||
ck::Tuple<ADataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
ave_time += launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.a_in_transpose_desc_),
|
||||
make_tuple(arg.a_out_transpose_desc_),
|
||||
make_tuple(arg.p_a_grid_),
|
||||
make_tuple(p_a_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_a_,
|
||||
element_wise::PassThrough{},
|
||||
arg.num_workgroups_per_Conv_N_,
|
||||
std::array<index_t, I1>{
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)},
|
||||
std::array<index_t, I1>{
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideA_)});
|
||||
}
|
||||
ave_time += RunGemm(arg, stream_config);
|
||||
// Transpose from NHWGC to NGCHW
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
|
||||
arg.e_in_transpose_desc_) *
|
||||
arg.num_workgroups_per_Conv_N_;
|
||||
|
||||
const EDataType* p_e_in_grid =
|
||||
type_convert<EDataType*>(arg.p_workspace_) +
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
|
||||
|
||||
EDataType* p_e_out_grid = arg.p_e_grid_;
|
||||
|
||||
auto kernel_transpose =
|
||||
kernel_batched_elementwise<GridwiseElementwiseOutputTranspose,
|
||||
ck::Tuple<NHWGCTransposeDescType>,
|
||||
ck::Tuple<NGCHWTransposeDescType>,
|
||||
ck::Tuple<const EDataType*>,
|
||||
ck::Tuple<EDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
element_wise::PassThrough,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
ave_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel_transpose,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlocksize),
|
||||
0,
|
||||
make_tuple(arg.e_in_transpose_desc_),
|
||||
make_tuple(arg.e_out_transpose_desc_),
|
||||
make_tuple(p_e_in_grid),
|
||||
make_tuple(p_e_out_grid),
|
||||
arg.elementwise_block_2_ctile_map_transpose_e_,
|
||||
element_wise::PassThrough{},
|
||||
arg.num_workgroups_per_Conv_N_,
|
||||
std::array<index_t, I1>{
|
||||
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideE_)},
|
||||
std::array<index_t, I1>{static_cast<index_t>(
|
||||
arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
@@ -765,6 +1038,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
@@ -787,7 +1061,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
@@ -848,7 +1124,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
|
||||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
@@ -874,6 +1152,48 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
|
||||
{
|
||||
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((ConvG * ConvK) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t a_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
const index_t e_spatial_acum = ck::accumulate_n<index_t>(
|
||||
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
|
||||
|
||||
if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout
|
||||
<< "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -998,11 +1318,48 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMXdlPerWavePerShuffle << ", "
|
||||
<< CShuffleNXdlPerWavePerShuffle
|
||||
<< ">";
|
||||
<< CShuffleNXdlPerWavePerShuffle;
|
||||
|
||||
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
|
||||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>()) {
|
||||
str << ", TransposeTransferInScalarPerVectorAligned: "
|
||||
<< TransposeTransferInScalarPerVectorAligned <<", "
|
||||
<< "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;
|
||||
}
|
||||
|
||||
|
||||
str << ">";
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -1621,6 +1621,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
@@ -834,6 +834,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
|
||||
@@ -771,12 +771,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
@@ -1293,6 +1297,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(!valid)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -586,12 +586,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
|
||||
std::size_t GetWorkspaceATensorSizeBytes() const
|
||||
{
|
||||
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
|
||||
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(ADataType) * a_acum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
|
||||
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
|
||||
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
|
||||
return sizeof(EDataType) * e_accum;
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
@@ -1207,6 +1211,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
|
||||
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
|
||||
arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// check vector access of E
|
||||
|
||||
Reference in New Issue
Block a user