Grouped conv bwd data NGCHW (#1967)

* Grouped conv bwd data NGCHW

* fixes

* fix

* Improvements

* Fix

* Fix

* add client example
This commit is contained in:
Bartłomiej Kocot
2025-03-17 13:32:00 +01:00
committed by GitHub
parent 52b1cd7780
commit c2e4898b4b
26 changed files with 1351 additions and 71 deletions

View File

@@ -1,11 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -13,7 +14,9 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -202,9 +205,11 @@ template <index_t NDimSpatial,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename AComputeType = ADataType,
typename BComputeType = AComputeType>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename AComputeType = ADataType,
typename BComputeType = AComputeType,
index_t MaxTransposeTransferInScalarPerVector = 1,
index_t MaxTransposeTransferOutScalarPerVector = 1>
struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
ALayout, // output image
@@ -237,6 +242,19 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
using ALayoutAfterTranspose =
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NHWGK,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NDHWGK,
ALayout>>;
using ELayoutAfterTranspose =
std::conditional_t<is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NHWGC,
std::conditional_t<is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>(),
tensor_layout::convolution::NDHWGC,
ELayout>>;
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
ConvBackwardDataSpecialization,
AK1,
@@ -246,9 +264,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
KPerBlock,
DoPadGemmM,
DoPadGemmN,
ALayout,
ALayoutAfterTranspose,
BLayout,
ELayout,
ELayoutAfterTranspose,
true, /*SplitConvN*/
ABDataType,
EDataType>;
@@ -274,7 +292,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
KPerBlock,
DoPadGemmM,
DoPadGemmN,
ALayout,
ALayoutAfterTranspose,
BLayout,
DLayout,
true, /*SplitConvN*/
@@ -374,7 +392,70 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// block-to-e-tile map
using Block2ETileMap =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Configuration of the elementwise transpose kernels that are launched around the GEMM
// when the NGCHW/NGCDHW layout combinations are used.

// Block-to-tile map shared by both transpose kernels, instantiated with <NPerBlock, MPerBlock>.
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, MPerBlock>;
// Cluster lengths read from entries 1 and 3 of the CDE block-transfer cluster descriptor.
static constexpr index_t ClusterLengthMPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
// Helper object used below to build the NGCHW / NHWGC transpose descriptors.
// The last two arguments are presumably the per-thread extents along N and M — TODO confirm
// against TransformConvNGCHWToNHWGC.
static constexpr auto conv_ngchw_to_nhwgc_transformer =
TransformConvNGCHWToNHWGC<ELayout,
BLayout,
ALayout,
NDimSpatial,
NPerBlock / ClusterLengthNPerBlock,
MPerBlock / ClusterLengthMPerBlock>{};
// Vector widths for the transposed side of each kernel: the user-supplied maximum clamped to
// MPerBlock / ClusterLengthMPerBlock.
static constexpr index_t TransposeTransferInScalarPerVectorAligned =
std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector);
static constexpr index_t TransposeTransferOutScalarPerVectorAligned =
std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector);
// Descriptor types deduced from the transformer's factory functions.
using NGCHWTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
using NHWGCTransposeDescType =
remove_cvref_t<decltype(conv_ngchw_to_nhwgc_transformer
.template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
// Block size used when launching the transpose kernels: product of the two cluster lengths.
static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock;
// Transpose applied to the A tensor before the GEMM: reads through an NGCHW-type descriptor
// and writes through an NHWGC-type descriptor, with a PassThrough element-wise op.
using GridwiseElementwiseInputTranspose =
GridwiseElementwise<Tuple<NGCHWTransposeDescType>,
Tuple<NHWGCTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
MPerBlock,
NPerBlock / ClusterLengthNPerBlock,
MPerBlock / ClusterLengthMPerBlock,
Sequence<1, 0>,
Sequence<TransposeTransferInScalarPerVectorAligned>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
I1,
I0>;
// Transpose applied to the E tensor after the GEMM: the mirror of the input transpose
// (NHWGC-type descriptor in, NGCHW-type descriptor out; vector widths and the trailing
// I0/I1 arguments are swapped accordingly).
using GridwiseElementwiseOutputTranspose =
GridwiseElementwise<Tuple<NHWGCTransposeDescType>,
Tuple<NGCHWTransposeDescType>,
Tuple<const EDataType*>,
Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
ElementwiseBlocksize,
NPerBlock,
MPerBlock,
NPerBlock / ClusterLengthNPerBlock,
MPerBlock / ClusterLengthMPerBlock,
Sequence<1, 0>,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
Sequence<TransposeTransferOutScalarPerVectorAligned>,
I0,
I1>;
// Argument
struct Argument : public BaseArgument
{
@@ -409,10 +490,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeStrides(e_g_n_c_wis_lengths,
e_g_n_c_wis_strides);
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
@@ -491,17 +580,18 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
ConvToGemmBwdDataTransform conv_to_gemm_transform_{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides_transposed,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
e_g_n_c_wis_strides_transposed,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
tildes};
conv_N_per_block_ = conv_to_gemm_transform_.N_;
@@ -527,7 +617,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
KPerBlock,
DoPadGemmM,
DoPadGemmN,
ALayout,
ALayoutAfterTranspose,
BLayout,
DLayout,
true, /*SplitConvN*/
@@ -535,7 +625,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DDataType>;
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
a_g_n_k_wos_strides_transposed,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
ds_g_n_c_wis_lengths[i],
@@ -591,12 +681,73 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideA_ =
a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ =
e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_;
num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
{
// Use the original (non-transposed) base strides
a_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
a_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_);
e_in_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
e_out_transpose_desc_ =
conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{
e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)};
compute_ptr_offset_of_workspace_n_.BatchStrideA_ =
a_g_n_k_wos_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_workspace_n_.BatchStrideE_ =
e_g_n_c_wis_strides[1] * conv_N_per_block_;
}
}
std::size_t GetWorkspaceATensorSizeBytes() const
{
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
// Total workspace required by this argument, in bytes.
// Only the NGCHW/NGCDHW layout combinations go through the transpose path, which needs
// room for both the A tensor and the E tensor; every other layout needs no workspace.
std::size_t GetWorkspaceSizeBytes() const
{
    constexpr bool uses_transpose_workspace =
        is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
        is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>();

    if constexpr(!uses_transpose_workspace)
    {
        return 0;
    }
    else
    {
        return GetWorkspaceATensorSizeBytes() + GetWorkspaceETensorSizeBytes();
    }
}
void Print() const
@@ -645,10 +796,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// block-to-e-tile map
std::vector<Block2ETileMap> block_2_etile_map_container_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_;
NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_workspace_n_;
// element-wise op
AElementwiseOp a_element_op_;
@@ -657,9 +814,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
std::array<index_t, NDimSpatial> conv_filter_strides_;
std::array<index_t, NDimSpatial> input_left_pads_;
std::array<index_t, NDimSpatial> input_right_pads_;
index_t num_workgroups_per_Conv_N_;
};
// Invoker
@@ -667,19 +827,24 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
float ave_time = 0;
const index_t gdy = arg.num_group_;
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdz = num_workgroups_per_Conv_N;
const index_t gdz = arg.num_workgroups_per_Conv_N_;
const ADataType* p_a_grid = arg.p_a_grid_;
EDataType* p_e_grid = arg.p_e_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_);
p_e_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
}
float ave_time = 0;
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
@@ -722,10 +887,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
p_a_grid,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
@@ -751,6 +916,114 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
return ave_time;
}
// Runs the complete backward-data convolution and returns the sum of the times reported
// by the launch helpers for every stage.
//
// For the NGCHW/NGCDHW layout combinations the GEMM is bracketed by two elementwise
// transpose kernels that go through arg.p_workspace_:
//   1. arg.p_a_grid_ is transposed into the start of the workspace,
//   2. RunGemm(...) is executed,
//   3. the E region of the workspace is transposed into arg.p_e_grid_.
// For every other layout only RunGemm(...) is executed.
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
if(stream_config.log_level_ > 0)
{
arg.Print();
}
// Transpose from NGKHW to NHWGK
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
{
// The E region starts right after the A region of the workspace.
// NOTE(review): the offset is expressed in EDataType elements via integer division, so
// it presumably relies on the A byte size being a multiple of sizeof(EDataType) — confirm.
EDataType* p_e_in_grid = type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
// Zero-fills the E region of the workspace; passed below as the preprocess callback of
// launch_and_time_kernel_with_preprocess.
// Presumably required because the GEMM stage does not overwrite every element — TODO confirm.
const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync(p_e_in_grid,
0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
};
// Grid size from the A block-to-tile map, multiplied by the number of N splits.
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_) *
arg.num_workgroups_per_Conv_N_;
// The transposed A tensor is written to the beginning of the workspace.
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
auto kernel_transpose =
kernel_batched_elementwise<GridwiseElementwiseInputTranspose,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
I1,
I1>;
// The two trailing arrays carry the per-N-split strides: first the one derived from the
// original A strides, then the one derived from the transposed A strides
// (presumably input side first, output side second — TODO confirm against the kernel).
ave_time += launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(p_a_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
element_wise::PassThrough{},
arg.num_workgroups_per_Conv_N_,
std::array<index_t, I1>{
static_cast<index_t>(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)},
std::array<index_t, I1>{
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideA_)});
}
ave_time += RunGemm(arg, stream_config);
// Transpose from NHWGC to NGCHW
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
{
// Grid size from the E block-to-tile map, multiplied by the number of N splits.
const index_t grid_size =
arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
arg.e_in_transpose_desc_) *
arg.num_workgroups_per_Conv_N_;
// Source is the E region of the workspace (same offset as computed before the GEMM);
// destination is the user-provided E buffer.
const EDataType* p_e_in_grid =
type_convert<EDataType*>(arg.p_workspace_) +
arg.GetWorkspaceATensorSizeBytes() / sizeof(EDataType);
EDataType* p_e_out_grid = arg.p_e_grid_;
auto kernel_transpose =
kernel_batched_elementwise<GridwiseElementwiseOutputTranspose,
ck::Tuple<NHWGCTransposeDescType>,
ck::Tuple<NGCHWTransposeDescType>,
ck::Tuple<const EDataType*>,
ck::Tuple<EDataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
I1,
I1>;
// Stride order is the reverse of the input transpose: first the stride derived from the
// transposed E strides, then the one derived from the original E strides.
ave_time += launch_and_time_kernel(
stream_config,
kernel_transpose,
dim3(grid_size),
dim3(ElementwiseBlocksize),
0,
make_tuple(arg.e_in_transpose_desc_),
make_tuple(arg.e_out_transpose_desc_),
make_tuple(p_e_in_grid),
make_tuple(p_e_out_grid),
arg.elementwise_block_2_ctile_map_transpose_e_,
element_wise::PassThrough{},
arg.num_workgroups_per_Conv_N_,
std::array<index_t, I1>{
static_cast<index_t>(arg.compute_ptr_offset_of_n_.BatchStrideE_)},
std::array<index_t, I1>{static_cast<index_t>(
arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)});
}
return ave_time;
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
@@ -765,6 +1038,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
return false;
}
const index_t ConvG = arg.b_g_k_c_xs_lengths_[0];
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
@@ -787,7 +1061,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
is_same_v<ALayout, tensor_layout::convolution::NDHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NGKHW> ||
is_same_v<ALayout, tensor_layout::convolution::NGKDHW>)
{
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
{
@@ -848,7 +1124,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
is_same_v<ELayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ELayout, tensor_layout::convolution::NGCHW> ||
is_same_v<ELayout, tensor_layout::convolution::NGCDHW>)
{
// vector store C matrix into global memory
if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0))
@@ -874,6 +1152,48 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>())
{
if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
if((ConvG * ConvK) % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
const index_t a_spatial_acum = ck::accumulate_n<index_t>(
arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
const index_t e_spatial_acum = ck::accumulate_n<index_t>(
arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>());
if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0)
{
return false;
}
if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0)
{
return false;
}
if(!arg.p_workspace_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout
<< "Warning: Workspace for "
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not "
"allocated, use SetWorkSpacePointer."
<< std::endl;
}
return false;
}
}
return true;
}
@@ -998,11 +1318,48 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle
<< ">";
<< CShuffleNXdlPerWavePerShuffle;
if constexpr(is_NGCHW_GKYXC_NGKHW<ELayout, BLayout, ALayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<ELayout, BLayout, ALayout>()) {
str << ", TransposeTransferInScalarPerVectorAligned: "
<< TransposeTransferInScalarPerVectorAligned <<", "
<< "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned;
}
str << ">";
return str.str();
}
// Returns the number of workspace bytes required by the given argument.
// Throws std::runtime_error when p_arg is not an Argument of this device operation.
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
    const auto* typed_arg = dynamic_cast<const Argument*>(p_arg);
    if(!typed_arg)
    {
        throw std::runtime_error(
            "The argument pointer is not an object of "
            "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
    }

    return typed_arg->GetWorkspaceSizeBytes();
}
// Stores the caller-provided workspace pointer in the given argument.
// The StreamConfig parameter is unused here.
// Throws std::runtime_error when p_arg is not an Argument of this device operation.
void SetWorkSpacePointer(BaseArgument* p_arg,
                         void* p_workspace,
                         const StreamConfig& = StreamConfig{}) const override
{
    auto* typed_arg = dynamic_cast<Argument*>(p_arg);
    if(!typed_arg)
    {
        throw std::runtime_error(
            "The argument pointer is not an object of "
            "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!");
    }

    typed_arg->p_workspace_ = p_workspace;
}
};
} // namespace device

View File

@@ -1621,6 +1621,13 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{
return false;
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
{
return false;
}
}
return true;

View File

@@ -834,6 +834,25 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
{
return false;
}
if(!arg.p_workspace_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Warning: Workspace for "
"DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument is not "
"allocated, use SetWorkSpacePointer."
<< std::endl;
}
return false;
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
{
return false;
}
}
// Gridwise GEMM size

View File

@@ -771,12 +771,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
std::size_t GetWorkspaceSizeBytes() const
@@ -1293,6 +1297,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return false;
}
if(!arg.p_workspace_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Warning: Workspace for "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument is not "
"allocated, use SetWorkSpacePointer."
<< std::endl;
}
return false;
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
}
if(!valid)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -586,12 +586,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
const long_index_t a_acum = ck::accumulate_n<long_index_t>(
a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(ADataType) * a_acum;
}
std::size_t GetWorkspaceETensorSizeBytes() const
{
return sizeof(EDataType) * e_out_transpose_desc_.GetElementSpaceSize();
const long_index_t e_accum = ck::accumulate_n<long_index_t>(
e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>());
return sizeof(EDataType) * e_accum;
}
std::size_t GetWorkspaceSizeBytes() const
@@ -1207,6 +1211,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
return false;
}
if(!arg.p_workspace_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Warning: Workspace for "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3::Argument is not "
"allocated, use SetWorkSpacePointer."
<< std::endl;
}
return false;
}
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
}
// check vector access of E