Revert Support access per groups and filter2x3 in grouped conv fwd (#1382) (#1406)

This commit is contained in:
Bartłomiej Kocot
2024-07-22 14:21:24 +02:00
committed by GitHub
parent 8c90f25be3
commit 5d8c3d8190
13 changed files with 43 additions and 600 deletions

View File

@@ -86,6 +86,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t groups_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -100,11 +101,14 @@ __global__ void
defined(__gfx94__))
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
const long_index_t e_group_offset =
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
@@ -117,14 +121,14 @@ __global__ void
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
if constexpr(isMultiA || isMultiB)
{
AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp;
const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
// compute_ptr_offset_of_n_ not need BatchStrideB so
// in case of MultiA is false but isMultiB is true
@@ -135,27 +139,27 @@ __global__ void
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}([&](auto i) {
p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
});
}
else
{
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
static_for<0, 1, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; });
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
}
const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}(
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid_grp,
p_bs_grid_grp,
p_ds_grid_grp,
p_e_grid + e_group_offset + e_n_offset,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
@@ -168,19 +172,19 @@ __global__ void
}
else
{
const long_index_t a_group_offset =
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_group_offset + a_n_offset,
p_bs_grid + b_group_offset,
p_as_grid + a_batch_offset + a_n_offset,
p_bs_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_group_offset + e_n_offset,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
@@ -196,6 +200,7 @@ __global__ void
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = groups_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -282,8 +287,7 @@ template <index_t NDimSpatial,
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType,
LoopScheduler LoopSched = make_default_loop_scheduler(),
index_t NumGroupsToMerge = 1>
LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
@@ -302,8 +306,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
static_assert(NumGroupsToMerge >= 1);
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
@@ -320,8 +322,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ConvForwardSpecialization,
true /*SplitN*/,
ALayout,
ELayout,
NumGroupsToMerge>;
ELayout>;
static constexpr auto matrix_padder =
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
@@ -520,8 +521,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideA_(i) =
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
@@ -549,8 +549,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideB_(i) =
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of the AB is a pointer and one is a tuple.
@@ -570,10 +569,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
compute_ptr_offset_of_groups_.BatchStrideA_ =
a_g_n_c_wis_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_ =
b_g_k_c_xs_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
// p_as and p_bs are pointers
@@ -590,8 +587,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
@@ -610,7 +606,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_m_n_(i) =
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
// populate desc for Ds/E
@@ -734,8 +730,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_ / NumGroupsToMerge;
const index_t gdz = num_workgroups_per_Conv_N;
const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
const index_t gdz = 1;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -784,6 +780,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -827,6 +824,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_g_n_c_wis_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -858,10 +856,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
namespace ctc = tensor_layout::convolution;
const index_t G = arg.b_g_k_c_xs_lengths_[I0];
const index_t K = arg.b_g_k_c_xs_lengths_[I1];
const index_t C = arg.b_g_k_c_xs_lengths_[I2];
// check device
if(get_device_name() == "gfx908")
{
@@ -910,42 +904,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
if(filter_spatial_dim != I3)
{
return false;
}
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
{
return false;
}
}
if constexpr(NumGroupsToMerge > 1)
{
if(!(C == 1))
{
return false;
}
if(G % NumGroupsToMerge != 0)
{
return false;
}
if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
{
return false;
}
}
// check vector access of A
// FIXME: layout
@@ -955,16 +913,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
is_same_v<ALayout, ctc::NDHWGC>)
{
// Check access per C
const index_t C = arg.a_g_n_c_wis_lengths_[2];
if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
{
// If not possible, check access per G
if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>() &&
G % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
return false;
}
}
else
@@ -981,6 +934,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<BLayout, ctc::KZYXGC>)
{
const index_t C = arg.b_g_k_c_xs_lengths_[2];
if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
{
return false;
@@ -1004,6 +959,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
{
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
valid = false;
@@ -1048,6 +1005,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
is_same_v<ELayout, ctc::NDHWGK>)
{
const index_t K = arg.e_g_n_k_wos_lengths_[2];
if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
@@ -1198,8 +1157,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<< BBlockTransferSrcScalarPerVector << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< NumGroupsToMerge
<< CShuffleNXdlPerWavePerShuffle
<< ">";
// clang-format on