Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-03 05:01:25 +00:00)
Enable multiple D for grouped conv fwd large tensors (#2572)
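The change threads the D tensors through the whole large-tensor grouped convolution forward path: the kernels apply per-group and per-N-split pointer offsets to each D tensor, the device op gains Ds grid descriptors and a CastDsPointers helper, the 2 GB problem splitter propagates the D pointers alongside the A and E pointers, and the former NumDTensor == 0 compile-time guards give way to runtime checks that each D layout matches the E (output) layout.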
@@ -106,9 +106,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);

const long_index_t e_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);

const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

@@ -121,7 +123,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; });

if constexpr(isMultiA || isMultiB)
{
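The two hunks above show the kernel-side pattern: each work-group re-bases every D pointer by a per-group offset plus a per-N-split offset before handing the tuple to the gridwise GEMM. A minimal stand-alone sketch of that indexing, using plain std::array instead of the CK tuple types (hypothetical names, host-side only):

#include <array>
#include <cstddef>
#include <cstdint>

// Sketch: re-base a set of D pointers for one work-group, assuming
// ds_group_offset[i] and ds_n_offset[i] are element offsets for tensor i.
template <typename T, std::size_t NumD>
std::array<const T*, NumD> rebase_ds(const std::array<const T*, NumD>& p_ds,
                                     const std::array<std::int64_t, NumD>& ds_group_offset,
                                     const std::array<std::int64_t, NumD>& ds_n_offset)
{
    std::array<const T*, NumD> p_ds_grp{};
    for(std::size_t i = 0; i < NumD; ++i)
    {
        // same composition as the kernel: base + per-N-split offset + per-group offset
        p_ds_grp[i] = p_ds[i] + ds_n_offset[i] + ds_group_offset[i];
    }
    return p_ds_grp;
}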
@@ -88,13 +88,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);

const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);

static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
using DsGridPointer = typename GridwiseGemm::DsGridPointer;
DsGridPointer p_ds_grid_grp{};

static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; });
static_for<0, NumDTensor, 1>{}([&](auto i) {
p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
});

const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));

@@ -168,13 +170,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);

const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);

static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
using DsGridPointer = typename GridwiseGemm::DsGridPointer;
DsGridPointer p_ds_grid_grp{};

static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; });
static_for<0, NumDTensor, 1>{}([&](auto i) {
p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
});

const long_index_t a_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
@@ -63,11 +63,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const long_index_t e_group_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));

const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

@@ -89,10 +91,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
group_id = index_t((left + right) / 2);
}

using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_);
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor = DsPointer::Size();
static_for<0, NumDTensor, 1>{}([&](auto i) {
p_ds_grid_grp(i) =
gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i];
});

GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset,
gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset,
Tuple<>{},
p_ds_grid_grp,
gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset,
p_shared,
a_element_op,

@@ -100,7 +110,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
c_element_op,
gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
Tuple<>{},
gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_kernel_args[group_id].block_2_etile_map_);
#else
@@ -259,18 +269,44 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return out_gemmm_gemmn_desc;
}

static auto
MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
},
Number<NumDTensor>{});
}

static auto CastDsPointers(const std::array<const void*, NumDTensor>& p_ds)
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(p_ds[i]);
},
Number<NumDTensor>{});
}

using DsPointer = decltype(CastDsPointers(std::array<const void*, NumDTensor>{}));
// desc for problem definition
constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer;
using AGridDesc_M_K =
remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
using BGridDesc_N_K =
remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
using EGridDesc_M_N =
remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;

static auto
GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base,
const ADataType* a_grid_ptr_base,
DsPointer ds_grid_ptr_base,
EDataType* c_grid_ptr_base)
{
// Max number of splits
@@ -279,11 +315,13 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// Arrays to store transformers with smaller descs than 2GB
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformers_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs_arr;
Array<DsPointer, MaxGemmsNum> ds_grid_ptrs_arr;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs_arr;
// Queue for spliting
std::queue<ConvToGemmFwdTransformerLongIndexT> conv_to_gemm_transformers_queue(
{conv_to_gemm_transformer_base});
std::queue<const ADataType*> a_grid_ptrs_queue({a_grid_ptr_base});
std::queue<DsPointer> ds_grid_ptrs_queue({ds_grid_ptr_base});
std::queue<EDataType*> c_grid_ptrs_queue({c_grid_ptr_base});

index_t gemms_number = 0;

@@ -300,6 +338,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// Get transformer from the queue
const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front();
const ADataType* a_grid_ptr = a_grid_ptrs_queue.front();
DsPointer ds_grid_ptr = ds_grid_ptrs_queue.front();
EDataType* c_grid_ptr = c_grid_ptrs_queue.front();

// Check if convolution not exceed 2GB

@@ -308,8 +347,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// If yes, push into result array
conv_to_gemm_transformers_arr(gemms_number) =
ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer};
a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
ds_grid_ptrs_arr(gemms_number) = ds_grid_ptr;
c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
gemms_number++;
}
else

@@ -318,19 +358,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part;
const ADataType* a_grid_right_ptr;
DsPointer ds_grid_right_ptr;
EDataType* c_grid_right_ptr;

ck::tie(conv_to_gemm_transformers_left_part,
conv_to_gemm_transformers_right_part,
a_grid_right_ptr,
ds_grid_right_ptr,
c_grid_right_ptr) =
conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr);
conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, ds_grid_ptr, c_grid_ptr);

conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part);
conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part);
// Left offsets remain the same
a_grid_ptrs_queue.push(a_grid_ptr);
a_grid_ptrs_queue.push(a_grid_right_ptr);
ds_grid_ptrs_queue.push(ds_grid_ptr);
ds_grid_ptrs_queue.push(ds_grid_right_ptr);
c_grid_ptrs_queue.push(c_grid_ptr);
c_grid_ptrs_queue.push(c_grid_right_ptr);
split_numbers++;

@@ -338,6 +382,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// Remove from the queue
conv_to_gemm_transformers_queue.pop();
a_grid_ptrs_queue.pop();
ds_grid_ptrs_queue.pop();
c_grid_ptrs_queue.pop();
}

@@ -345,6 +390,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor

return ck::make_tuple(conv_to_gemm_transformers_arr,
a_grid_ptrs_arr,
ds_grid_ptrs_arr,
c_grid_ptrs_arr,
gemms_number,
is_split_valid);
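GenerateConvToGemmTransforms keeps splitting the convolution until every resulting GEMM descriptor stays under the 2 GB indexing limit, and with this change the D pointers ride along in their own queue exactly like the A and E pointers. A simplified, stand-alone sketch of that breadth-first splitting loop (hypothetical Problem and split() stand-ins, not the CK API):

#include <queue>
#include <utility>
#include <vector>

struct Problem
{
    long long bytes; // size of the largest tensor descriptor; A, D and E pointers would ride along here
};

// Hypothetical splitter: halve the problem (CK splits along N or spatial dims).
static std::pair<Problem, Problem> split(const Problem& p)
{
    return {Problem{p.bytes / 2}, Problem{p.bytes - p.bytes / 2}};
}

// Keep splitting until every piece fits under the 2 GB limit,
// mirroring the queue-based loop in GenerateConvToGemmTransforms.
std::vector<Problem> split_until_fits(Problem base, long long limit = (1LL << 31))
{
    std::vector<Problem> fits;
    std::queue<Problem> pending({base});
    while(!pending.empty())
    {
        Problem p = pending.front();
        pending.pop();
        if(p.bytes < limit)
        {
            fits.push_back(p); // small enough: becomes one GEMM
        }
        else
        {
            auto [left, right] = split(p); // too big: split and re-queue both halves
            pending.push(left);
            pending.push(right);
        }
    }
    return fits;
}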
@@ -375,6 +421,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;

@@ -388,11 +437,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
// pointers
const ADataType* a_ptr_;
const BDataType* b_ptr_;
DsPointer ds_ptr_;
EDataType* e_ptr_;

// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;

// block-to-e-tile map

@@ -405,16 +457,16 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
{
Argument(const void* p_a,
const void* p_b,
const std::array<const void*, NumDTensor>& /*p_ds*/,
const std::array<const void*, NumDTensor>& p_ds,
void* p_e,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_lengths*/,
ds_g_n_k_wos_lengths,
const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
/*ds_g_n_k_wos_strides*/,
ds_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const std::array<long_index_t, NDimSpatial>& conv_filter_strides,

@@ -434,6 +486,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
conv_filter_strides_{conv_filter_strides},
@@ -441,94 +495,105 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
if constexpr(NumDTensor == 0)
// Perform grouped gemm, generate array of tranformer for convolution
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
Array<DsPointer, MaxGemmsNum> ds_grid_ptrs;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs;

DsPointer p_ds_casted = CastDsPointers(p_ds);

ck::tie(conv_to_gemm_transformer_arr,
a_grid_ptrs,
ds_grid_ptrs,
c_grid_ptrs,
gemms_count_,
is_split_valid_) =
GenerateConvToGemmTransforms(
ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
static_cast<const ADataType*>(p_a),
p_ds_casted,
static_cast<EDataType*>(p_e));

grid_size_ = 0;
valid_gemms_count_ = 0;

if(is_split_valid_)
{
// Perform grouped gemm, generate array of tranformer for convolution
Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
Array<EDataType*, MaxGemmsNum> c_grid_ptrs;

ck::tie(conv_to_gemm_transformer_arr,
a_grid_ptrs,
c_grid_ptrs,
gemms_count_,
is_split_valid_) =
GenerateConvToGemmTransforms(
ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
a_g_n_c_wis_strides_,
b_g_k_c_xs_lengths_,
b_g_k_c_xs_strides_,
e_g_n_k_wos_lengths_,
e_g_n_k_wos_strides_,
conv_filter_strides_,
conv_filter_dilations_,
input_left_pads_,
input_right_pads_},
static_cast<const ADataType*>(p_a),
static_cast<EDataType*>(p_e));

grid_size_ = 0;
valid_gemms_count_ = 0;

if(is_split_valid_)
// Create GemmArg for each gemm(conv)
for(index_t i = 0; i < gemms_count_; i++)
{
// Create GemmArg for each gemm(conv)
for(index_t i = 0; i < gemms_count_; i++)
const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
conv_to_gemm_transformer_arr[i])};
const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
conv_to_gemm_transformer_arr[i])};
const auto e_grid_desc_m_n =
DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);

const auto ds_grid_desc_m_n =
generate_tuple([&](auto) { return e_grid_desc_m_n; }, Number<NumDTensor>{});

const auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);

const index_t grid_size_grp =
block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);

const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;

grid_size_ += grid_size_grp;

if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map))
{
const AGridDesc_M_K a_grid_desc_m_k{
DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
conv_to_gemm_transformer_arr[i])};
const BGridDesc_N_K b_grid_desc_n_k{
DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
conv_to_gemm_transformer_arr[i])};
const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
conv_to_gemm_transformer_arr[i]);
gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
a_grid_ptrs[i],
static_cast<const BDataType*>(p_b),
ds_grid_ptrs[i],
c_grid_ptrs[i],
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n),
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n),
block_2_etile_map,
BlockStart,
BlockEnd};

const auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);

const index_t grid_size_grp =
block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);

const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;

grid_size_ += grid_size_grp;

if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
Tuple<>{},
e_grid_desc_m_n,
block_2_etile_map))
{

gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
a_grid_ptrs[i],
static_cast<const BDataType*>(p_b),
c_grid_ptrs[i],
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n),
block_2_etile_map,
BlockStart,
BlockEnd};

valid_gemms_count_++;
}
valid_gemms_count_++;
}
// N is the same for all convs
conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
}

// Strides for G and N remain the same
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];

compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
// N is the same for all convs
conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
}

// Strides for G and N remain the same
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];

compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;

static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
});
}

void Print() const
@@ -558,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
bool is_split_valid_;

// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_n_;

// element-wise op
AElementwiseOperation a_element_op_;

@@ -571,6 +636,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
std::array<long_index_t, NDimSpatial> conv_filter_strides_;
@@ -584,63 +651,55 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
{
float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if constexpr(NumDTensor == 0)
if(stream_config.log_level_ > 0)
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
arg.Print();
}

const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;

const index_t gdx = arg.grid_size_;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;
const index_t gdx = arg.grid_size_;
const index_t gdy = arg.num_group_;
const index_t gdz = num_workgroups_per_Conv_N;

// K is constant for all gemms
const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
// K is constant for all gemms
const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);

auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel =
kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
GridwiseGemm,
MaxGemmsNum,
GemmArgs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop>;
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
GridwiseGemm,
MaxGemmsNum,
GemmArgs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>;

return launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.gemm_desc_kernel_args_,
arg.gemms_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
};
return launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.gemm_desc_kernel_args_,
arg.gemms_count_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
};

if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return 0.f;
return launch_kernel(integral_constant<bool, false>{});
}
}
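The launch configuration in Run is unchanged by this patch: blockIdx.x walks the concatenated tile ranges of all split GEMMs, blockIdx.y selects the convolution group, and blockIdx.z selects the N split; only the pointer-offset helper passed to the kernel now carries NumDTensor D strides. A small sketch of how that 3D grid is assembled (assumed parameter names, HIP/CUDA dim3):

#include <hip/hip_runtime.h>

// Sketch: grid dimensions used by the grouped-gemm launcher, assuming
//   grid_size  = sum of block_2_etile_map tile counts over all split GEMMs,
//   num_group  = G, the number of convolution groups,
//   n_splits   = N / conv_N_per_block.
inline dim3 make_launch_grid(unsigned grid_size, unsigned num_group, unsigned n_splits)
{
    return dim3(grid_size, num_group, n_splits); // (gdx, gdy, gdz)
}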
@@ -657,9 +716,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor

const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];
// Move this to runtime check to align Conv instances
// with Conv Multiple D instances
if constexpr(NumDTensor != 0)

bool ds_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
for(int d = 0; d < NDimSpatial + I3; d++)
{
if(arg.ds_g_n_k_wos_strides_[i][d] != arg.e_g_n_k_wos_strides_[d])
{
ds_valid = false;
}
if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
{
ds_valid = false;
}
}

using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
static_assert(is_same_v<DDataType, EDataType>);
});

if(!ds_valid)
{
return false;
}