Add grouped conv fwd bias relu instances (#2179)

* Add grouped conv fwd bias relu instances

* fixes

* fix
Bartłomiej Kocot, 2025-05-09 22:52:34 +02:00 (committed by GitHub)
parent 6b1a339b6f
commit 6fddb5708c

33 changed files with 2477 additions and 550 deletions
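
Summary of the approach: instead of rejecting multi-A/B/D template configurations at compile time with `static_assert`, the device ops below guard their bodies with `if constexpr` and report `false` from `IsSupportedArgument`, so the new bias+relu (multi-D) instance lists can instantiate the same templates and get filtered out at dispatch time. A minimal standalone sketch of that pattern (hypothetical names, not CK code):

    #include <cstdio>

    template <int NumDTensor>
    struct DeviceOpSketch
    {
        // Before: static_assert(NumDTensor == 0, "MultiD not supported.");
        // After: every configuration compiles; unsupported ones are rejected at runtime.
        bool IsSupportedArgument() const
        {
            if constexpr(NumDTensor != 0)
            {
                return false; // multi-D path not implemented here; another instance handles it
            }
            return true; // remaining runtime checks would go here
        }

        float Run() const
        {
            if constexpr(NumDTensor == 0)
            {
                return 1.0f; // stands in for the real kernel launch + timing
            }
            else
            {
                return 0.f; // never taken; exists only so the template instantiates
            }
        }
    };

    int main()
    {
        DeviceOpSketch<0> supported;
        DeviceOpSketch<1> multi_d;
        std::printf("%d %d\n", supported.IsSupportedArgument(), multi_d.IsSupportedArgument());
    }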

@@ -279,9 +279,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     static constexpr bool isMultiD = DsDataType::Size() > 0;
     static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD;

-    // multi ABD not supported
-    static_assert(!isMultiABD, "Multi A, Mutli B and Multi D are not supported");
-
     static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
     static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -1080,91 +1077,96 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
     float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
         float avg_time = 0.f;

-        if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
-                     is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
+        if constexpr(!isMultiABD)
         {
-            const index_t a_grid_size =
-                arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
-                    arg.a_in_transpose_desc_);
-            const index_t b_grid_size =
-                arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
-                    arg.b_in_transpose_desc_);
-
-            ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
-            BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
-                                      arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
-
-            auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
-                                                            GridwiseElementwiseWeightTranspose,
-                                                            ck::Tuple<NGCHWTransposeDescType>,
-                                                            ck::Tuple<GKCYXTransposeDescType>,
-                                                            ck::Tuple<NHWGCTransposeDescType>,
-                                                            ck::Tuple<GKYXCTransposeDescType>,
-                                                            ck::Tuple<const ADataType*>,
-                                                            ck::Tuple<const BDataType*>,
-                                                            ck::Tuple<ADataType*>,
-                                                            ck::Tuple<BDataType*>,
-                                                            Block2TileMapElementwise,
-                                                            Block2TileMapElementwise,
-                                                            element_wise::PassThrough>;
-
-            avg_time += launch_and_time_kernel(stream_config,
-                                               kernel_transpose,
-                                               dim3(a_grid_size + b_grid_size),
-                                               dim3(ElementwiseBlocksize),
-                                               0,
-                                               make_tuple(arg.a_in_transpose_desc_),
-                                               make_tuple(arg.b_in_transpose_desc_),
-                                               make_tuple(arg.a_out_transpose_desc_),
-                                               make_tuple(arg.b_out_transpose_desc_),
-                                               make_tuple(arg.p_a_grid_),
-                                               make_tuple(arg.p_b_grid_),
-                                               make_tuple(p_a_out_grid),
-                                               make_tuple(p_b_out_grid),
-                                               arg.elementwise_block_2_ctile_map_transpose_a_,
-                                               arg.elementwise_block_2_ctile_map_transpose_b_,
-                                               element_wise::PassThrough{},
-                                               a_grid_size);
-        }
+            if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
+                         is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
+            {
+                const index_t a_grid_size =
+                    arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
+                        arg.a_in_transpose_desc_);
+                const index_t b_grid_size =
+                    arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
+                        arg.b_in_transpose_desc_);
+
+                ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_);
+                BDataType* p_b_out_grid =
+                    type_convert<BDataType*>(arg.p_workspace_) +
+                    arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
+
+                auto kernel_transpose =
+                    kernel_elementwise_dual<GridwiseElementwiseInputTranspose,
+                                            GridwiseElementwiseWeightTranspose,
+                                            ck::Tuple<NGCHWTransposeDescType>,
+                                            ck::Tuple<GKCYXTransposeDescType>,
+                                            ck::Tuple<NHWGCTransposeDescType>,
+                                            ck::Tuple<GKYXCTransposeDescType>,
+                                            ck::Tuple<const ADataType*>,
+                                            ck::Tuple<const BDataType*>,
+                                            ck::Tuple<ADataType*>,
+                                            ck::Tuple<BDataType*>,
+                                            Block2TileMapElementwise,
+                                            Block2TileMapElementwise,
+                                            element_wise::PassThrough>;
+
+                avg_time +=
+                    launch_and_time_kernel(stream_config,
+                                           kernel_transpose,
+                                           dim3(a_grid_size + b_grid_size),
+                                           dim3(ElementwiseBlocksize),
+                                           0,
+                                           make_tuple(arg.a_in_transpose_desc_),
+                                           make_tuple(arg.b_in_transpose_desc_),
+                                           make_tuple(arg.a_out_transpose_desc_),
+                                           make_tuple(arg.b_out_transpose_desc_),
+                                           make_tuple(arg.p_a_grid_),
+                                           make_tuple(arg.p_b_grid_),
+                                           make_tuple(p_a_out_grid),
+                                           make_tuple(p_b_out_grid),
+                                           arg.elementwise_block_2_ctile_map_transpose_a_,
+                                           arg.elementwise_block_2_ctile_map_transpose_b_,
+                                           element_wise::PassThrough{},
+                                           a_grid_size);
+            }
+
+            avg_time += RunGemm(arg, stream_config);
+
+            if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
+                         is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
+            {
+                const index_t grid_size =
+                    arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
+                        arg.e_in_transpose_desc_);
+
+                const EDataType* p_e_in_grid =
+                    type_convert<EDataType*>(arg.p_workspace_) +
+                    (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
+                        sizeof(EDataType);
+
+                EDataType* p_e_out_grid = arg.p_e_grid_;
+
+                auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
+                                                           ck::Tuple<NHWGCTransposeDescType>,
+                                                           ck::Tuple<NGCHWTransposeDescType>,
+                                                           ck::Tuple<const EDataType*>,
+                                                           ck::Tuple<EDataType*>,
+                                                           Block2TileMapElementwise,
+                                                           element_wise::PassThrough>;
+
+                avg_time +=
+                    launch_and_time_kernel(stream_config,
+                                           kernel_transpose,
+                                           dim3(grid_size),
+                                           dim3(ElementwiseBlocksize),
+                                           0,
+                                           make_tuple(arg.e_in_transpose_desc_),
+                                           make_tuple(arg.e_out_transpose_desc_),
+                                           make_tuple(p_e_in_grid),
+                                           make_tuple(p_e_out_grid),
+                                           arg.elementwise_block_2_ctile_map_transpose_e_,
+                                           element_wise::PassThrough{});
+            }
         }

-        avg_time += RunGemm(arg, stream_config);
-
-        if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
-                     is_NGCDHW_GKCZYX_NGKDHW<ALayout, BLayout, ELayout>())
-        {
-            const index_t grid_size =
-                arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize(
-                    arg.e_in_transpose_desc_);
-
-            const EDataType* p_e_in_grid =
-                type_convert<EDataType*>(arg.p_workspace_) +
-                (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) /
-                    sizeof(EDataType);
-
-            EDataType* p_e_out_grid = arg.p_e_grid_;
-
-            auto kernel_transpose = kernel_elementwise<GridwiseElementwiseOutputTranspose,
-                                                       ck::Tuple<NHWGCTransposeDescType>,
-                                                       ck::Tuple<NGCHWTransposeDescType>,
-                                                       ck::Tuple<const EDataType*>,
-                                                       ck::Tuple<EDataType*>,
-                                                       Block2TileMapElementwise,
-                                                       element_wise::PassThrough>;
-
-            avg_time += launch_and_time_kernel(stream_config,
-                                               kernel_transpose,
-                                               dim3(grid_size),
-                                               dim3(ElementwiseBlocksize),
-                                               0,
-                                               make_tuple(arg.e_in_transpose_desc_),
-                                               make_tuple(arg.e_out_transpose_desc_),
-                                               make_tuple(p_e_in_grid),
-                                               make_tuple(p_e_out_grid),
-                                               arg.elementwise_block_2_ctile_map_transpose_e_,
-                                               element_wise::PassThrough{});
-        }
-
         return avg_time;
     }
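
The `if constexpr(!isMultiABD)` wrapper above matters because a discarded `if constexpr` branch inside a template is never instantiated, so the transpose-and-GEMM path does not have to compile for multi-ABD instantiations. A freestanding illustration of that mechanism (toy types, not CK code):

    #include <cstdio>

    template <bool MultiD>
    struct Op;

    template <>
    struct Op<false>
    {
        float RunGemm() const { return 1.f; } // real work only in the supported case
    };

    template <>
    struct Op<true>
    {
        // intentionally no RunGemm()
    };

    template <bool MultiD>
    float Run(const Op<MultiD>& op)
    {
        if constexpr(!MultiD)
        {
            return op.RunGemm(); // discarded for Op<true>, so it never type-checks there
        }
        return 0.f;
    }

    int main()
    {
        // Both instantiations compile; only the supported one does work.
        std::printf("%.1f %.1f\n", Run(Op<false>{}), Run(Op<true>{}));
    }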
@@ -1182,6 +1184,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
         const index_t G = arg.b_g_k_c_xs_lengths_[I0];
         const index_t K = arg.b_g_k_c_xs_lengths_[I1];
         const index_t C = arg.b_g_k_c_xs_lengths_[I2];

+        // Move this to runtime check to align Conv instances
+        // with Conv Multiple D instances
+        if constexpr(isMultiABD)
+        {
+            return false;
+        }

         // check device
         if(get_device_name() == "gfx908")
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
@@ -192,7 +192,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
     static constexpr index_t NumDTensor = DsDataType::Size();
     static constexpr index_t MaxGemmsNum = 32;

-    static_assert(NumDTensor == 0, "MultiD not supported.");

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
@@ -440,89 +439,94 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
           input_left_pads_{input_left_pads},
           input_right_pads_{input_right_pads}
     {
-        // Perform grouped gemm, generate array of tranformer for convolution
-        Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
-        Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
-        Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
-
-        ck::tie(conv_to_gemm_transformer_arr,
-                a_grid_ptrs,
-                c_grid_ptrs,
-                gemms_count_,
-                is_split_valid_) =
-            GenerateConvToGemmTransforms(
-                ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
-                                                   a_g_n_c_wis_strides_,
-                                                   b_g_k_c_xs_lengths_,
-                                                   b_g_k_c_xs_strides_,
-                                                   e_g_n_k_wos_lengths_,
-                                                   e_g_n_k_wos_strides_,
-                                                   conv_filter_strides_,
-                                                   conv_filter_dilations_,
-                                                   input_left_pads_,
-                                                   input_right_pads_},
-                static_cast<const ADataType*>(p_a),
-                static_cast<EDataType*>(p_e));
-
-        grid_size_ = 0;
-        valid_gemms_count_ = 0;
-
-        if(is_split_valid_)
-        {
-            // Create GemmArg for each gemm(conv)
-            for(index_t i = 0; i < gemms_count_; i++)
-            {
-                const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
-                    conv_to_gemm_transformer_arr[i])};
-                const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
-                    conv_to_gemm_transformer_arr[i])};
-                const auto e_grid_desc_m_n =
-                    DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);
-
-                const auto block_2_etile_map =
-                    GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
-
-                const index_t grid_size_grp =
-                    block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
-
-                const index_t BlockStart = grid_size_;
-                const index_t BlockEnd   = grid_size_ + grid_size_grp;
-
-                grid_size_ += grid_size_grp;
-
-                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
-                                               b_grid_desc_n_k,
-                                               Tuple<>{},
-                                               e_grid_desc_m_n,
-                                               block_2_etile_map))
-                {
-                    gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
-                        a_grid_ptrs[i],
-                        static_cast<const BDataType*>(p_b),
-                        c_grid_ptrs[i],
-                        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
-                        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
-                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                            e_grid_desc_m_n),
-                        block_2_etile_map,
-                        BlockStart,
-                        BlockEnd};
-
-                    valid_gemms_count_++;
-                }
-            }
-        }
-
-        // N is the same for all convs
-        conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
-
-        // Strides for G and N remain the same
-        compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-        compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
-        compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
-
-        compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
-        compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
+        if constexpr(NumDTensor == 0)
+        {
+            // Perform grouped gemm, generate array of tranformer for convolution
+            Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
+            Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
+            Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
+
+            ck::tie(conv_to_gemm_transformer_arr,
+                    a_grid_ptrs,
+                    c_grid_ptrs,
+                    gemms_count_,
+                    is_split_valid_) =
+                GenerateConvToGemmTransforms(
+                    ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
+                                                       a_g_n_c_wis_strides_,
+                                                       b_g_k_c_xs_lengths_,
+                                                       b_g_k_c_xs_strides_,
+                                                       e_g_n_k_wos_lengths_,
+                                                       e_g_n_k_wos_strides_,
+                                                       conv_filter_strides_,
+                                                       conv_filter_dilations_,
+                                                       input_left_pads_,
+                                                       input_right_pads_},
+                    static_cast<const ADataType*>(p_a),
+                    static_cast<EDataType*>(p_e));
+
+            grid_size_ = 0;
+            valid_gemms_count_ = 0;
+
+            if(is_split_valid_)
+            {
+                // Create GemmArg for each gemm(conv)
+                for(index_t i = 0; i < gemms_count_; i++)
+                {
+                    const AGridDesc_M_K a_grid_desc_m_k{
+                        DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
+                            conv_to_gemm_transformer_arr[i])};
+                    const BGridDesc_N_K b_grid_desc_n_k{
+                        DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
+                            conv_to_gemm_transformer_arr[i])};
+                    const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
+                        conv_to_gemm_transformer_arr[i]);
+
+                    const auto block_2_etile_map =
+                        GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
+
+                    const index_t grid_size_grp =
+                        block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
+
+                    const index_t BlockStart = grid_size_;
+                    const index_t BlockEnd   = grid_size_ + grid_size_grp;
+
+                    grid_size_ += grid_size_grp;
+
+                    if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
+                                                   b_grid_desc_n_k,
+                                                   Tuple<>{},
+                                                   e_grid_desc_m_n,
+                                                   block_2_etile_map))
+                    {
+                        gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
+                            a_grid_ptrs[i],
+                            static_cast<const BDataType*>(p_b),
+                            c_grid_ptrs[i],
+                            GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
+                            GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
+                            GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                                e_grid_desc_m_n),
+                            block_2_etile_map,
+                            BlockStart,
+                            BlockEnd};
+
+                        valid_gemms_count_++;
+                    }
+                }
+            }
+
+            // N is the same for all convs
+            conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
+
+            // Strides for G and N remain the same
+            compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+            compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+            compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+
+            compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
+            compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
+        }
     }

     void Print() const
@@ -578,55 +582,63 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
    {
        float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
-           if(stream_config.log_level_ > 0)
-           {
-               arg.Print();
-           }
-
-           const index_t num_workgroups_per_Conv_N =
-               arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
-
-           const index_t gdx = arg.grid_size_;
-           const index_t gdy = arg.num_group_;
-           const index_t gdz = num_workgroups_per_Conv_N;
-
-           // K is constant for all gemms
-           const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
-                          arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
-
-           auto launch_kernel = [&](auto has_main_k_block_loop) {
-               constexpr bool has_main_loop = has_main_k_block_loop.value;
-
-               const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
-                   GridwiseGemm,
-                   MaxGemmsNum,
-                   GemmArgs,
-                   AElementwiseOperation,
-                   BElementwiseOperation,
-                   CDEElementwiseOperation,
-                   ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-                   has_main_loop>;
-
-               return launch_and_time_kernel(stream_config,
-                                             kernel,
-                                             dim3(gdx, gdy, gdz),
-                                             dim3(BlockSize),
-                                             0,
-                                             arg.gemm_desc_kernel_args_,
-                                             arg.gemms_count_,
-                                             arg.a_element_op_,
-                                             arg.b_element_op_,
-                                             arg.cde_element_op_,
-                                             arg.compute_ptr_offset_of_groups_,
-                                             arg.compute_ptr_offset_of_n_);
-           };
-
-           if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
-           {
-               return launch_kernel(integral_constant<bool, true>{});
-           }
-           else
-           {
-               return launch_kernel(integral_constant<bool, false>{});
-           }
+           if constexpr(NumDTensor == 0)
+           {
+               if(stream_config.log_level_ > 0)
+               {
+                   arg.Print();
+               }
+
+               const index_t num_workgroups_per_Conv_N =
+                   arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
+
+               const index_t gdx = arg.grid_size_;
+               const index_t gdy = arg.num_group_;
+               const index_t gdz = num_workgroups_per_Conv_N;
+
+               // K is constant for all gemms
+               const auto K =
+                   arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+                   arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+               auto launch_kernel = [&](auto has_main_k_block_loop) {
+                   constexpr bool has_main_loop = has_main_k_block_loop.value;
+
+                   const auto kernel =
+                       kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
+                           GridwiseGemm,
+                           MaxGemmsNum,
+                           GemmArgs,
+                           AElementwiseOperation,
+                           BElementwiseOperation,
+                           CDEElementwiseOperation,
+                           ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
+                           has_main_loop>;
+
+                   return launch_and_time_kernel(stream_config,
+                                                 kernel,
+                                                 dim3(gdx, gdy, gdz),
+                                                 dim3(BlockSize),
+                                                 0,
+                                                 arg.gemm_desc_kernel_args_,
+                                                 arg.gemms_count_,
+                                                 arg.a_element_op_,
+                                                 arg.b_element_op_,
+                                                 arg.cde_element_op_,
+                                                 arg.compute_ptr_offset_of_groups_,
+                                                 arg.compute_ptr_offset_of_n_);
+               };
+
+               if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+               {
+                   return launch_kernel(integral_constant<bool, true>{});
+               }
+               else
+               {
+                   return launch_kernel(integral_constant<bool, false>{});
+               }
+           }
+           else
+           {
+               return 0.f;
+           }
        }
@@ -643,6 +655,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
        const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
        const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];

+       // Move this to runtime check to align Conv instances
+       // with Conv Multiple D instances
+       if constexpr(NumDTensor != 0)
+       {
+           return false;
+       }

        // Check if all descs are valid
        if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_))
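
The practical effect of the two new runtime checks is that instance-selection loops keep working unchanged: an op whose configuration is unsupported reports `false` from `IsSupportedArgument` and is skipped, instead of failing the whole build as the old `static_assert` did. A freestanding sketch of that dispatch pattern (stand-in interface, not the actual CK base classes):

    #include <cstdio>
    #include <memory>
    #include <vector>

    // Stand-in interface mirroring the IsSupportedArgument/Run split used above.
    struct DeviceOpBase
    {
        virtual ~DeviceOpBase() = default;
        virtual bool IsSupportedArgument() const = 0;
        virtual float Run() const = 0;
    };

    struct Rejecting : DeviceOpBase
    {
        bool IsSupportedArgument() const override { return false; } // e.g. NumDTensor != 0
        float Run() const override { return 0.f; }
    };

    struct Accepting : DeviceOpBase
    {
        bool IsSupportedArgument() const override { return true; }
        float Run() const override { return 1.f; } // stands in for a timed kernel launch
    };

    int main()
    {
        std::vector<std::unique_ptr<DeviceOpBase>> ops;
        ops.push_back(std::make_unique<Rejecting>());
        ops.push_back(std::make_unique<Accepting>());

        for(const auto& op : ops)
        {
            if(op->IsSupportedArgument()) // runtime filter replaces the old static_assert
            {
                std::printf("ran: %.1f\n", op->Run());
            }
        }
    }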