Enable multiple D for grouped conv fwd large tensors (#2572)

Author: Bartłomiej Kocot
Date: 2025-07-28 22:39:07 +02:00
Committed by: GitHub
Parent: 0782ee8eb3
Commit: 5b244105d9

7 changed files with 377 additions and 148 deletions

View File

@@ -106,9 +106,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
     // offset base pointer for each work-group
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
     const long_index_t e_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
+    const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
     const long_index_t e_n_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
@@ -121,7 +123,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
         DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
     static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
+        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; });
     if constexpr(isMultiA || isMultiB)
     {

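What this first kernel change does: each work-group now offsets every one of its NumDTensor auxiliary D tensors by a per-group offset (blockIdx.y) plus a per-N offset (blockIdx.z) before handing the pointers to the gridwise gemm. A minimal standalone sketch of that pointer arithmetic, with plain arrays and a runtime loop standing in for CK's Tuple and static_for, and a single float element type standing in for the heterogeneous DsDataType (all names and numbers below are illustrative):

#include <array>
#include <cstdint>

template <int NumDTensor>
std::array<const float*, NumDTensor>
offset_ds_pointers(const std::array<const float*, NumDTensor>& p_ds,
                   const std::array<std::int64_t, NumDTensor>& ds_group_offset,
                   const std::array<std::int64_t, NumDTensor>& ds_n_offset)
{
    std::array<const float*, NumDTensor> p_ds_grp{};
    for(int i = 0; i < NumDTensor; ++i)
    {
        // Same composition as the diff: base + per-N offset + per-group offset.
        p_ds_grp[i] = p_ds[i] + ds_n_offset[i] + ds_group_offset[i];
    }
    return p_ds_grp;
}

int main()
{
    static float d0[1024], d1[1024];
    const auto grp = offset_ds_pointers<2>({d0, d1}, {128, 256}, {32, 64});
    // d0 advances by 160 elements, d1 by 320.
    return (grp[0] == d0 + 160 && grp[1] == d1 + 320) ? 0 : 1;
}

On the device both offsets come from ComputePtrOffsetOfStridedBatch (filled in by the Argument constructor later in this commit), so each work-group adds exactly one precomputed stride per D tensor.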
View File

@@ -88,13 +88,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
     const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
     static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
     using DsGridPointer = typename GridwiseGemm::DsGridPointer;
     DsGridPointer p_ds_grid_grp{};
-    static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; });
+    static_for<0, NumDTensor, 1>{}([&](auto i) {
+        p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
+    });
     const long_index_t a_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
@@ -168,13 +170,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
     const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
     static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor;
     using DsGridPointer = typename GridwiseGemm::DsGridPointer;
     DsGridPointer p_ds_grid_grp{};
-    static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_group_offset[i]; });
+    static_for<0, NumDTensor, 1>{}([&](auto i) {
+        p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i];
+    });
     const long_index_t a_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));

View File

@@ -63,11 +63,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
     const long_index_t b_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
+    const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
     const long_index_t e_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
     const long_index_t a_n_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+    const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
     const long_index_t e_n_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
@@ -89,10 +91,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
         group_id = index_t((left + right) / 2);
     }
+    using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_);
+    DsPointer p_ds_grid_grp;
+    static constexpr index_t NumDTensor = DsPointer::Size();
+    static_for<0, NumDTensor, 1>{}([&](auto i) {
+        p_ds_grid_grp(i) =
+            gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i];
+    });
     GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
         gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset,
         gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset,
-        Tuple<>{},
+        p_ds_grid_grp,
         gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset,
         p_shared,
         a_element_op,
@@ -100,7 +110,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
         c_element_op,
         gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
         gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
-        Tuple<>{},
+        gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
         gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
         gemm_desc_kernel_args[group_id].block_2_etile_map_);
 #else
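For context, this kernel serves all sub-gemms produced by the 2GB split: each work-group binary-searches the [BlockStart, BlockEnd) ranges stored in gemm_desc_kernel_args to find the sub-gemm that owns its blockIdx.x (the ranges are filled in by the Argument constructor further down). A standalone sketch of that dispatch with hypothetical descriptors:

#include <cstdio>

struct GemmDesc
{
    int block_start_; // first block owned by this sub-gemm
    int block_end_;   // one past the last block it owns
};

int find_group(const GemmDesc* descs, int gemm_count, int block_id)
{
    // Same shape as the kernel's loop: narrow [left, right) until the
    // candidate's range contains block_id.
    int left = 0, right = gemm_count, group_id = (left + right) / 2;
    while(!(block_id >= descs[group_id].block_start_ && block_id < descs[group_id].block_end_))
    {
        if(block_id < descs[group_id].block_start_)
            right = group_id;
        else
            left = group_id;
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    const GemmDesc descs[] = {{0, 4}, {4, 10}, {10, 12}};
    std::printf("block 7 -> gemm %d\n", find_group(descs, 3, 7)); // prints "block 7 -> gemm 1"
    return 0;
}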
@@ -259,18 +269,44 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         return out_gemmm_gemmn_desc;
     }

+    static auto
+    MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
+            },
+            Number<NumDTensor>{});
+    }
+
+    static auto CastDsPointers(const std::array<const void*, NumDTensor>& p_ds)
+    {
+        return generate_tuple(
+            [&](auto i) {
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+                return static_cast<const DDataType*>(p_ds[i]);
+            },
+            Number<NumDTensor>{});
+    }
+
+    using DsPointer = decltype(CastDsPointers(std::array<const void*, NumDTensor>{}));
+
     // desc for problem definition
     constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer;
     using AGridDesc_M_K =
         remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(dummy_conv_to_gemm_transformer))>;
     using BGridDesc_N_K =
         remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>(dummy_conv_to_gemm_transformer))>;
+    using DsGridDesc_M_N =
+        remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
     using EGridDesc_M_N =
         remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>(dummy_conv_to_gemm_transformer))>;

     static auto
     GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base,
                                  const ADataType* a_grid_ptr_base,
+                                 DsPointer ds_grid_ptr_base,
                                  EDataType* c_grid_ptr_base)
     {
         // Max number of splits
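CastDsPointers is the host-side glue for the new D support: the public interface hands over a type-erased std::array<const void*, NumDTensor>, and generate_tuple produces one statically typed pointer per D tensor. A rough standard-library equivalent (C++20; std::tuple and an index_sequence replace CK's Tuple and Number, and the element types below are stand-ins for DsDataType):

#include <array>
#include <tuple>
#include <utility>

template <typename... DsDataType>
auto cast_ds_pointers(const std::array<const void*, sizeof...(DsDataType)>& p_ds)
{
    // One static_cast per element, like generate_tuple over Number<NumDTensor>.
    return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        return std::make_tuple(static_cast<const DsDataType*>(p_ds[Is])...);
    }(std::index_sequence_for<DsDataType...>{});
}

int main()
{
    float bias[4]   = {};
    double scale[4] = {};
    std::array<const void*, 2> p_ds{bias, scale};
    auto typed = cast_ds_pointers<float, double>(p_ds); // std::tuple<const float*, const double*>
    return (std::get<0>(typed) == bias && std::get<1>(typed) == scale) ? 0 : 1;
}

MakeDsGridDescriptor_M_N follows the same generate_tuple pattern to build one E-style descriptor per D tensor, which is consistent with the validity check at the end of this diff requiring each D to match E's lengths and strides.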
@@ -279,11 +315,13 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         // Arrays to store transformers with smaller descs than 2GB
         Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformers_arr;
         Array<const ADataType*, MaxGemmsNum> a_grid_ptrs_arr;
+        Array<DsPointer, MaxGemmsNum> ds_grid_ptrs_arr;
         Array<EDataType*, MaxGemmsNum> c_grid_ptrs_arr;
         // Queue for splitting
         std::queue<ConvToGemmFwdTransformerLongIndexT> conv_to_gemm_transformers_queue(
             {conv_to_gemm_transformer_base});
         std::queue<const ADataType*> a_grid_ptrs_queue({a_grid_ptr_base});
+        std::queue<DsPointer> ds_grid_ptrs_queue({ds_grid_ptr_base});
         std::queue<EDataType*> c_grid_ptrs_queue({c_grid_ptr_base});
         index_t gemms_number = 0;
@@ -300,6 +338,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
             // Get transformer from the queue
             const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front();
             const ADataType* a_grid_ptr = a_grid_ptrs_queue.front();
+            DsPointer ds_grid_ptr = ds_grid_ptrs_queue.front();
             EDataType* c_grid_ptr = c_grid_ptrs_queue.front();
             // Check that the convolution does not exceed 2GB
@@ -308,8 +347,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
                 // If yes, push into result array
                 conv_to_gemm_transformers_arr(gemms_number) =
                     ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer};
-                a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
-                c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
+                a_grid_ptrs_arr(gemms_number)  = a_grid_ptr;
+                ds_grid_ptrs_arr(gemms_number) = ds_grid_ptr;
+                c_grid_ptrs_arr(gemms_number)  = c_grid_ptr;
                 gemms_number++;
             }
             else
else
@@ -318,19 +358,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
                 ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part,
                     conv_to_gemm_transformers_right_part;
                 const ADataType* a_grid_right_ptr;
+                DsPointer ds_grid_right_ptr;
                 EDataType* c_grid_right_ptr;
                 ck::tie(conv_to_gemm_transformers_left_part,
                         conv_to_gemm_transformers_right_part,
                         a_grid_right_ptr,
+                        ds_grid_right_ptr,
                         c_grid_right_ptr) =
-                    conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr);
+                    conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, ds_grid_ptr, c_grid_ptr);
                 conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part);
                 conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part);
                 // Left offsets remain the same
                 a_grid_ptrs_queue.push(a_grid_ptr);
                 a_grid_ptrs_queue.push(a_grid_right_ptr);
+                ds_grid_ptrs_queue.push(ds_grid_ptr);
+                ds_grid_ptrs_queue.push(ds_grid_right_ptr);
                 c_grid_ptrs_queue.push(c_grid_ptr);
                 c_grid_ptrs_queue.push(c_grid_right_ptr);
                 split_numbers++;
@@ -338,6 +382,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
             // Remove from the queue
             conv_to_gemm_transformers_queue.pop();
             a_grid_ptrs_queue.pop();
+            ds_grid_ptrs_queue.pop();
             c_grid_ptrs_queue.pop();
         }
@@ -345,6 +390,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         return ck::make_tuple(conv_to_gemm_transformers_arr,
                               a_grid_ptrs_arr,
+                              ds_grid_ptrs_arr,
                               c_grid_ptrs_arr,
                               gemms_number,
                               is_split_valid);
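The loop above splits one oversized convolution breadth-first: pop a problem, accept it if its descriptors fit under the 2GB indexing limit, otherwise split it into left and right parts via SplitConvProblem and re-queue both; the result is only valid if everything lands within MaxGemmsNum gemms. A simplified sketch of the same control flow, with byte counts standing in for the transformer objects and a plain halving standing in for the real split:

#include <cstdint>
#include <queue>
#include <vector>

int main()
{
    constexpr std::int64_t limit_bytes = std::int64_t{2} << 30; // the 2GB descriptor limit
    constexpr int max_gemms = 32;                               // stands in for MaxGemmsNum

    std::queue<std::int64_t> pending({std::int64_t{9} << 30}); // one oversized problem
    std::vector<std::int64_t> accepted;

    while(!pending.empty() && static_cast<int>(accepted.size()) < max_gemms)
    {
        const std::int64_t bytes = pending.front();
        pending.pop();
        if(bytes <= limit_bytes)
        {
            accepted.push_back(bytes); // fits: becomes one sub-gemm
        }
        else
        {
            // Stand-in for SplitConvProblem, which returns left/right sub-problems.
            pending.push(bytes / 2);
            pending.push(bytes - bytes / 2);
        }
    }
    const bool is_split_valid = pending.empty(); // everything placed within max_gemms
    return is_split_valid ? 0 : 1; // here the 9GB problem ends up as 8 sub-gemms
}

The commit threads the D pointers through exactly this machinery: they ride along in their own queue and array, and the right-hand part of each split receives the advanced ds_grid_right_ptr returned by SplitConvProblem.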
@@ -375,6 +421,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
     using BGridDesc_BK0_N_BK1 =
         remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
             BGridDesc_N_K{}))>;
+    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
+        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            DsGridDesc_M_N{}))>;
     using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
         remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
             EGridDesc_M_N{}))>;
@@ -388,11 +437,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         // pointers
         const ADataType* a_ptr_;
         const BDataType* b_ptr_;
+        DsPointer ds_ptr_;
         EDataType* e_ptr_;
         // tensor descriptors for block/thread-wise copy
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
         EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
         // block-to-e-tile map
@@ -405,16 +457,16 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
     {
         Argument(const void* p_a,
                  const void* p_b,
-                 const std::array<const void*, NumDTensor>& /*p_ds*/,
+                 const std::array<const void*, NumDTensor>& p_ds,
                  void* p_e,
                  const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                  const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                  const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                  const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                  const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
-                     /*ds_g_n_k_wos_lengths*/,
+                     ds_g_n_k_wos_lengths,
                  const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
-                     /*ds_g_n_k_wos_strides*/,
+                     ds_g_n_k_wos_strides,
                  const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                  const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                  const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
@@ -434,6 +486,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
               a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
               b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
               b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
+              ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
+              ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
               e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
               e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
               conv_filter_strides_{conv_filter_strides},
@@ -441,94 +495,105 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
           input_left_pads_{input_left_pads},
           input_right_pads_{input_right_pads}
     {
-        if constexpr(NumDTensor == 0)
-        {
-            // Perform grouped gemm, generate array of transformers for convolution
-            Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
-            Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
-            Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
-            ck::tie(conv_to_gemm_transformer_arr,
-                    a_grid_ptrs,
-                    c_grid_ptrs,
-                    gemms_count_,
-                    is_split_valid_) =
-                GenerateConvToGemmTransforms(
-                    ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
-                                                       a_g_n_c_wis_strides_,
-                                                       b_g_k_c_xs_lengths_,
-                                                       b_g_k_c_xs_strides_,
-                                                       e_g_n_k_wos_lengths_,
-                                                       e_g_n_k_wos_strides_,
-                                                       conv_filter_strides_,
-                                                       conv_filter_dilations_,
-                                                       input_left_pads_,
-                                                       input_right_pads_},
-                    static_cast<const ADataType*>(p_a),
-                    static_cast<EDataType*>(p_e));
-            grid_size_ = 0;
-            valid_gemms_count_ = 0;
-            if(is_split_valid_)
-            {
-                // Create GemmArg for each gemm(conv)
-                for(index_t i = 0; i < gemms_count_; i++)
-                {
-                    const AGridDesc_M_K a_grid_desc_m_k{
-                        DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
-                            conv_to_gemm_transformer_arr[i])};
-                    const BGridDesc_N_K b_grid_desc_n_k{
-                        DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
-                            conv_to_gemm_transformer_arr[i])};
-                    const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
-                        conv_to_gemm_transformer_arr[i]);
-                    const auto block_2_etile_map =
-                        GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
-                    const index_t grid_size_grp =
-                        block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
-                    const index_t BlockStart = grid_size_;
-                    const index_t BlockEnd = grid_size_ + grid_size_grp;
-                    grid_size_ += grid_size_grp;
-                    if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
-                                                   b_grid_desc_n_k,
-                                                   Tuple<>{},
-                                                   e_grid_desc_m_n,
-                                                   block_2_etile_map))
-                    {
-                        gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
-                            a_grid_ptrs[i],
-                            static_cast<const BDataType*>(p_b),
-                            c_grid_ptrs[i],
-                            GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
-                            GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
-                            GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
-                                e_grid_desc_m_n),
-                            block_2_etile_map,
-                            BlockStart,
-                            BlockEnd};
-                        valid_gemms_count_++;
-                    }
-                }
-                // N is the same for all convs
-                conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
-            }
-            // Strides for G and N remain the same
-            compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-            compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
-            compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
-            compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
-            compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
-        }
+        // Perform grouped gemm, generate array of transformers for convolution
+        Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
+        Array<const ADataType*, MaxGemmsNum> a_grid_ptrs;
+        Array<DsPointer, MaxGemmsNum> ds_grid_ptrs;
+        Array<EDataType*, MaxGemmsNum> c_grid_ptrs;
+        DsPointer p_ds_casted = CastDsPointers(p_ds);
+        ck::tie(conv_to_gemm_transformer_arr,
+                a_grid_ptrs,
+                ds_grid_ptrs,
+                c_grid_ptrs,
+                gemms_count_,
+                is_split_valid_) =
+            GenerateConvToGemmTransforms(
+                ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_,
+                                                   a_g_n_c_wis_strides_,
+                                                   b_g_k_c_xs_lengths_,
+                                                   b_g_k_c_xs_strides_,
+                                                   e_g_n_k_wos_lengths_,
+                                                   e_g_n_k_wos_strides_,
+                                                   conv_filter_strides_,
+                                                   conv_filter_dilations_,
+                                                   input_left_pads_,
+                                                   input_right_pads_},
+                static_cast<const ADataType*>(p_a),
+                p_ds_casted,
+                static_cast<EDataType*>(p_e));
+        grid_size_ = 0;
+        valid_gemms_count_ = 0;
+        if(is_split_valid_)
+        {
+            // Create GemmArg for each gemm(conv)
+            for(index_t i = 0; i < gemms_count_; i++)
+            {
+                const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(
+                    conv_to_gemm_transformer_arr[i])};
+                const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(
+                    conv_to_gemm_transformer_arr[i])};
+                const auto e_grid_desc_m_n =
+                    DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);
+                // Ds must match E's layout, so reuse E's M_N descriptor for each D
+                const auto ds_grid_desc_m_n =
+                    generate_tuple([&](auto) { return e_grid_desc_m_n; }, Number<NumDTensor>{});
+                const auto block_2_etile_map =
+                    GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
+                const index_t grid_size_grp =
+                    block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
+                const index_t BlockStart = grid_size_;
+                const index_t BlockEnd = grid_size_ + grid_size_grp;
+                grid_size_ += grid_size_grp;
+                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
+                                               b_grid_desc_n_k,
+                                               ds_grid_desc_m_n,
+                                               e_grid_desc_m_n,
+                                               block_2_etile_map))
+                {
+                    gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{
+                        a_grid_ptrs[i],
+                        static_cast<const BDataType*>(p_b),
+                        ds_grid_ptrs[i],
+                        c_grid_ptrs[i],
+                        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
+                        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
+                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            ds_grid_desc_m_n),
+                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                            e_grid_desc_m_n),
+                        block_2_etile_map,
+                        BlockStart,
+                        BlockEnd};
+                    valid_gemms_count_++;
+                }
+            }
+            // N is the same for all convs
+            conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
+        }
+        // Strides for G and N remain the same
+        compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
+        compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+        compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+        compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
+        compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
+        static_for<0, NumDTensor, 1>{}([&](auto i) {
+            compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
+            compute_ptr_offset_of_n_.BatchStrideDs_(i) =
+                ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
+        });
     }
void Print() const
@@ -558,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         bool is_split_valid_;
         // for computing batch offset
-        ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
-        ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_groups_;
+        ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_n_;
         // element-wise op
         AElementwiseOperation a_element_op_;
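Both offset helpers are now parameterized with NumDTensor, so each carries a BatchStrideDs_ array next to the A/B/E strides. The constructor fills them the same way as for E: index 0 of a D tensor's stride array is the step between groups, and index 1 (the N stride) is scaled by conv_N_per_block_ because blockIdx.z advances one block of images at a time. A worked example with invented strides:

#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical ds_g_n_k_wos_strides[i] for a 2D conv: {G, N, K, Ho, Wo} strides.
    const std::int64_t d_strides[5]     = {65536, 2048, 1, 128, 64};
    const std::int64_t conv_N_per_block = 4; // images covered per blockIdx.z step

    const std::int64_t batch_stride_ds_g = d_strides[0];                    // group step
    const std::int64_t batch_stride_ds_n = d_strides[1] * conv_N_per_block; // N-block step

    // Offset the kernels add for work-group (g_idx, n_idx), cf. GetDsPtrOffset:
    const int g_idx = 3, n_idx = 2;
    std::printf("ds offset = %lld elements\n",
                static_cast<long long>(g_idx * batch_stride_ds_g + n_idx * batch_stride_ds_n));
    return 0;
}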
@@ -571,6 +636,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
         std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
         std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
+        std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
+        std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
         std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
         std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
         std::array<long_index_t, NDimSpatial> conv_filter_strides_;
@@ -584,63 +651,55 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
     {
         float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            if constexpr(NumDTensor == 0)
-            {
-                if(stream_config.log_level_ > 0)
-                {
-                    arg.Print();
-                }
-
-                const index_t num_workgroups_per_Conv_N =
-                    arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
-
-                const index_t gdx = arg.grid_size_;
-                const index_t gdy = arg.num_group_;
-                const index_t gdz = num_workgroups_per_Conv_N;
-
-                // K is constant for all gemms
-                const auto K =
-                    arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
-                    arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
-
-                auto launch_kernel = [&](auto has_main_k_block_loop) {
-                    constexpr bool has_main_loop = has_main_k_block_loop.value;
-                    const auto kernel =
-                        kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
-                            GridwiseGemm,
-                            MaxGemmsNum,
-                            GemmArgs,
-                            AElementwiseOperation,
-                            BElementwiseOperation,
-                            CDEElementwiseOperation,
-                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
-                            has_main_loop>;
-
-                    return launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(gdx, gdy, gdz),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.gemm_desc_kernel_args_,
-                                                  arg.gemms_count_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.cde_element_op_,
-                                                  arg.compute_ptr_offset_of_groups_,
-                                                  arg.compute_ptr_offset_of_n_);
-                };
-
-                if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
-                {
-                    return launch_kernel(integral_constant<bool, true>{});
-                }
-                else
-                {
-                    return launch_kernel(integral_constant<bool, false>{});
-                }
-            }
-            else
-            {
-                return 0.f;
-            }
+            if(stream_config.log_level_ > 0)
+            {
+                arg.Print();
+            }
+
+            const index_t num_workgroups_per_Conv_N =
+                arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
+
+            const index_t gdx = arg.grid_size_;
+            const index_t gdy = arg.num_group_;
+            const index_t gdz = num_workgroups_per_Conv_N;
+
+            // K is constant for all gemms
+            const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+                           arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+
+            auto launch_kernel = [&](auto has_main_k_block_loop) {
+                constexpr bool has_main_loop = has_main_k_block_loop.value;
+                const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
+                    GridwiseGemm,
+                    MaxGemmsNum,
+                    GemmArgs,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CDEElementwiseOperation,
+                    ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
+                    has_main_loop>;
+
+                return launch_and_time_kernel(stream_config,
+                                              kernel,
+                                              dim3(gdx, gdy, gdz),
+                                              dim3(BlockSize),
+                                              0,
+                                              arg.gemm_desc_kernel_args_,
+                                              arg.gemms_count_,
+                                              arg.a_element_op_,
+                                              arg.b_element_op_,
+                                              arg.cde_element_op_,
+                                              arg.compute_ptr_offset_of_groups_,
+                                              arg.compute_ptr_offset_of_n_);
+            };
+
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                return launch_kernel(integral_constant<bool, true>{});
+            }
+            else
+            {
+                return launch_kernel(integral_constant<bool, false>{});
+            }
         }
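The launch shape encodes the whole decomposition: gdx concatenates the tile grids of all valid sub-gemms (the BlockStart/BlockEnd bookkeeping from the Argument constructor), gdy walks the convolution groups, and gdz walks the N blocks. A small sketch of that bookkeeping with hypothetical tile counts:

#include <cstdio>

int main()
{
    const int per_gemm_tiles[3] = {4, 6, 2}; // hypothetical CalculateGridSize() results
    int grid_size = 0;
    for(int i = 0; i < 3; i++)
    {
        const int block_start = grid_size; // BlockStart for sub-gemm i
        const int block_end   = grid_size + per_gemm_tiles[i];
        grid_size += per_gemm_tiles[i];
        std::printf("gemm %d owns blocks [%d, %d)\n", i, block_start, block_end);
    }
    const int num_group = 2, N = 16, conv_N_per_block = 4;
    // dim3(gdx, gdy, gdz) as in Run() above:
    std::printf("grid = (%d, %d, %d)\n", grid_size, num_group, N / conv_N_per_block);
    return 0;
}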
@@ -657,9 +716,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
         const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
         const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];

-        if constexpr(NumDTensor != 0)
+        // Moved from a compile-time rejection to a runtime check, to align Conv
+        // instances with Conv Multiple D instances
+        bool ds_valid = true;
+        static_for<0, NumDTensor, 1>{}([&](auto i) {
+            // Each D must have exactly E's lengths and strides (and E's data type)
+            for(int d = 0; d < NDimSpatial + I3; d++)
+            {
+                if(arg.ds_g_n_k_wos_strides_[i][d] != arg.e_g_n_k_wos_strides_[d])
+                {
+                    ds_valid = false;
+                }
+                if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
+                {
+                    ds_valid = false;
+                }
+            }
+            using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+            static_assert(is_same_v<DDataType, EDataType>);
+        });
+        if(!ds_valid)
         {
             return false;
         }