Add support for GKCYX grouped conv fwd (#2015)

* Add support for GKCYX grouped conv fwd

* fixes

* fix

* changelog

* Fixes
This commit is contained in:
Bartłomiej Kocot
2025-03-26 21:13:38 +01:00
committed by GitHub
parent fd915b83f7
commit 54c81a1fcf
39 changed files with 1005 additions and 570 deletions

View File

@@ -496,11 +496,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
input_right_pads_{input_right_pads}
{
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(e_g_n_c_wis_lengths,
e_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(e_g_n_c_wis_lengths,
e_g_n_c_wis_strides);
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {

View File

@@ -534,11 +534,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
const auto descs =
conv_to_gemm_transformer_v2
@@ -1425,11 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
GridwiseElementwiseTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,

View File

@@ -453,11 +453,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
b_g_n_c_wis_strides);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
const auto descs =
conv_to_gemm_transformer
@@ -641,11 +641,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
// Different data type for A and B is not supported
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
GridwiseElementwiseTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,

View File

@@ -314,8 +314,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
// NGCHW is not supported for multiAB
static_assert(!(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>()) ||
static_assert(!(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()) ||
!(isMultiA || isMultiB));
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
@@ -355,11 +355,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGC, ALay>>;
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<Layout>();
@@ -373,8 +371,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename BLay>
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
ctc::GKYXC,
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::GKZYXC, BLay>>;
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
@@ -387,11 +391,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
std::conditional_t<is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>(), ctc::NDHWGK, ELay>>;
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<Layout>();
@@ -491,6 +493,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
@@ -511,6 +520,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
I1,
I0>;
using GridwiseElementwiseWeightTranspose =
GridwiseElementwise<Tuple<GKCYXTransposeDescType>,
Tuple<GKYXCTransposeDescType>,
Tuple<const BDataType*>,
Tuple<BDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<1>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
@@ -558,14 +585,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
@@ -744,8 +772,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -755,6 +783,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
b_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
b_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
@@ -764,6 +799,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
}
@@ -771,25 +808,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceATensorSizeBytes() const
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
}
else
{
@@ -797,6 +822,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceSizeBytes() const
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
@@ -849,10 +911,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
GKCYXTransposeDescType b_in_transpose_desc_;
GKYXCTransposeDescType b_out_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
@@ -942,14 +1006,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else
{
const ADataType* p_a_grid = arg.p_as_grid_.At(I0);
const BDataType* p_b_grid = arg.p_bs_grid_.At(I0);
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
else if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
(arg.GetWorkspaceATensorSizeBytes() +
arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
@@ -978,8 +1056,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid, // Pass just A descriptor instead of tuple
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
@@ -1009,50 +1087,71 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
const index_t b_grid_size =
(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_)
: 0; // Dont run transpose B if not needed
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
GridwiseElementwiseWeightTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const BDataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(a_grid_size + b_grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_as_grid_.At(I0)),
make_tuple(arg.p_bs_grid_.At(I0)),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
a_grid_size);
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
const EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
EDataType* p_e_out_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
@@ -1069,8 +1168,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
make_tuple(p_e_out_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
@@ -1114,12 +1213,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1131,11 +1230,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1156,10 +1255,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
}
if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
{
return false;
}
}
if constexpr(NumGroupsToMerge > 1)
@@ -1173,7 +1268,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return false;
}
if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()))
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>() ||
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()))
{
return false;
}
@@ -1194,7 +1291,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) &&
(is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() ||
is_NGCSpatial_GKSpatial_NGKSpatial<ALayout, BLayout, ELayout>()) &&
is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>()) &&
G % ABlockTransferSrcScalarPerVector == 0))
{
return false;
@@ -1212,7 +1310,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
is_same_v<BLayout, ctc::KZYXGC> || is_same_v<BLayout, ctc::GKCX> ||
is_same_v<BLayout, ctc::GKCYX> || is_same_v<BLayout, ctc::GKCZYX>)
{
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
@@ -1270,8 +1369,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{

View File

@@ -325,9 +325,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGC,
ALay>>;
@@ -353,8 +353,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static auto
MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
ctc::GKYXC,
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::GKZYXC,
BLay>>;
const auto wei_gemmnraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
conv_to_gemm_transformer.template MakeBDescriptor_N_K<Layout>();
const auto wei_gemmn_gemmk_desc =
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
@@ -377,9 +385,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
namespace ctc = tensor_layout::convolution;
using Layout = std::conditional_t<
is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>(),
is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>(),
ctc::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>(),
std::conditional_t<is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>(),
ctc::NDHWGK,
ELay>>;
@@ -426,6 +434,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
using GKCYXTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
using GKYXCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock;
using GridwiseElementwiseInputTranspose =
@@ -446,6 +461,24 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
I1,
I0>;
using GridwiseElementwiseWeightTranspose =
GridwiseElementwise<Tuple<GKCYXTransposeDescType>,
Tuple<GKYXCTransposeDescType>,
Tuple<const BDataType*>,
Tuple<BDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
NPerBlock,
NPerBlock / ClusterLengthNPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<1>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I0,
I1>;
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
@@ -508,12 +541,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides)},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeStrides(
e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides)},
conv_filter_strides_{conv_filter_strides},
conv_filter_dilations_{conv_filter_dilations},
@@ -559,8 +593,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
// Use not modified base strides
a_in_transpose_desc_ =
@@ -570,9 +604,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_c_wis_lengths, a_g_n_c_wis_strides);
b_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
b_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
b_g_k_c_xs_lengths, b_g_k_c_xs_strides);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides);
@@ -586,25 +629,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
std::size_t GetWorkspaceATensorSizeBytes() const
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose require workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128;
}
else
{
@@ -612,6 +643,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t b_acum = ck::accumulate_n<long_index_t>(
b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
// Align to 128B
return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(is_NGCHW_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_NGKDHW<ALayout, BLayout, ELayout>())
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
else
{
return 0;
}
}
std::size_t GetWorkspaceSizeBytes() const
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
void Print() const
{
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
@@ -661,10 +729,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// block-to-e-tile map
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
GKCYXTransposeDescType b_in_transpose_desc_;
GKYXCTransposeDescType b_out_transpose_desc_;
};
// Invoker
@@ -702,18 +772,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
p_e_grid =
type_convert<EDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
}
typename GridwiseGemm::Argument gemm_arg{
p_a_grid, arg.p_b_grid_, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, I1};
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
@@ -1012,50 +1087,68 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
float avg_time = 0.f;
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
const index_t a_grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
const index_t b_grid_size =
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough>;
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
GridwiseElementwiseWeightTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<GKCYXTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<GKYXCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<const BDataType*>,
ck::Tuple<ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size),
dim3(a_grid_size + b_grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(arg.p_b_grid_),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{});
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
a_grid_size);
}
avg_time += RunGemm(arg, stream_config);
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_);
const EDataType* p_e_out_grid =
const EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
(arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
sizeof(EDataType);
EDataType* p_e_in_grid = arg.p_e_grid_;
EDataType* p_e_out_grid = arg.p_e_grid_;
auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
@@ -1072,8 +1165,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_out_grid),
make_tuple(p_e_in_grid),
make_tuple(p_e_out_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{});
}
@@ -1118,12 +1211,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1135,11 +1228,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i];
if(!(X == 1 && LeftPad == 0 && RightPad == 0))
if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
{
return false;
}
@@ -1171,7 +1264,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
is_same_v<BLayout, ctc::KZYXGC>)
is_same_v<BLayout, ctc::KZYXGC> || is_same_v<BLayout, ctc::GKCX> ||
is_same_v<BLayout, ctc::GKCYX> || is_same_v<BLayout, ctc::GKCZYX>)
{
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
@@ -1184,8 +1278,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return false;
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ALayout, BLayout, ELayout>())
if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
{
if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -59,6 +59,22 @@ constexpr bool is_NGCHW_GKYXC_NGKHW()
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCHW_GKCYX_NGKHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKCYX> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCHW_NGKHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
}
// 3d
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
@@ -84,6 +100,21 @@ constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
is_same_v<WeiLayout, tensor_layout::convolution::GKCZYX> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NGCDHW_NGKDHW()
{
return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
}
template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
{