mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Support multi AB for grouped conv fwd xdl (#1027)
* Support multi AB for grouped conv fwd xdl * Add instances * Add client example * Add example * Add interface test * Minor fixes Minor fixes Minor fixes * Comment fixes * Fixes * Reference fix * Test xdl fixes * Improve multi_ab interface test
This commit is contained in:
@@ -263,19 +263,18 @@ struct DeviceColumnToImageImpl
|
||||
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, InputGridDesc>(
|
||||
InputGridDesc{}))>;
|
||||
|
||||
using GridwiseTensorRearrangeKernel =
|
||||
GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Add,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>>;
|
||||
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Add,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<>>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -453,7 +452,7 @@ struct DeviceColumnToImageImpl
|
||||
std::vector<const InputDataType*> p_in_container_;
|
||||
std::vector<OutputDataType*> p_out_container_;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
@@ -471,7 +470,7 @@ struct DeviceColumnToImageImpl
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
// Execute each set of independent filters
|
||||
|
||||
@@ -385,9 +385,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
|
||||
// desc for blockwise copy
|
||||
using AsGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(AsGridDesc_M_K{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(
|
||||
AsGridDesc_M_K{}))>;
|
||||
using BsGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(BsGridDesc_N_K{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(
|
||||
BsGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
@@ -397,7 +399,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -429,7 +431,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
bs_grid_desc_bk0_n_bk1_{},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op}
|
||||
@@ -481,10 +483,10 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
|
||||
block_2_etile_map_))
|
||||
{
|
||||
as_grid_desc_ak0_m_ak1_ =
|
||||
GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
|
||||
GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
|
||||
|
||||
bs_grid_desc_bk0_n_bk1_ =
|
||||
GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
|
||||
GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
|
||||
@@ -305,9 +305,11 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
|
||||
|
||||
// desc for blockwise copy
|
||||
using AsGridDesc_AK0_M_AK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(AsGridDesc_M_K{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(
|
||||
AsGridDesc_M_K{}))>;
|
||||
using BsGridDesc_BK0_N_BK1 =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(BsGridDesc_N_K{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(
|
||||
BsGridDesc_N_K{}))>;
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
@@ -317,7 +319,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap =
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -349,7 +351,7 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
|
||||
bs_grid_desc_bk0_n_bk1_{},
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
@@ -407,10 +409,10 @@ struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD<AsLayou
|
||||
block_2_etile_map_))
|
||||
{
|
||||
as_grid_desc_ak0_m_ak1_ =
|
||||
GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
|
||||
GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_);
|
||||
|
||||
bs_grid_desc_bk0_n_bk1_ =
|
||||
GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
|
||||
GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
|
||||
@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOp a_element_op_;
|
||||
@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
|
||||
@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
std::vector<Block2ETileMap> block_2_etile_map_container_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOp a_element_op_;
|
||||
@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
|
||||
@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
OutElementwiseOperation a_element_op_;
|
||||
@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
has_main_loop,
|
||||
has_double_loop>;
|
||||
|
||||
|
||||
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
|
||||
OutElementwiseOperation a_element_op_;
|
||||
InElementwiseOperation b_element_op_;
|
||||
@@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
has_main_loop>;
|
||||
|
||||
using EmptyTuple = Tuple<>;
|
||||
|
||||
@@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
@@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
DefaultBlock2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11,
|
||||
DefaultBlock2CTileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
has_main_loop,
|
||||
has_double_loop>;
|
||||
|
||||
|
||||
@@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -56,7 +57,8 @@ namespace {
|
||||
*
|
||||
*/
|
||||
template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename AsPointer, // tuples if multi AB, pointers if no
|
||||
typename BsPointer,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
@@ -68,14 +70,16 @@ template <typename GridwiseGemm,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
bool HasMainKBlockLoop>
|
||||
bool HasMainKBlockLoop,
|
||||
bool isMultiA,
|
||||
bool isMultiB>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
AsPointer p_as_grid,
|
||||
BsPointer p_bs_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
@@ -98,14 +102,9 @@ __global__ void
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
@@ -117,22 +116,63 @@ __global__ void
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map);
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
AsPointer p_as_grid_grp;
|
||||
BsPointer p_bs_grid_grp;
|
||||
|
||||
const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
|
||||
|
||||
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
|
||||
|
||||
const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
|
||||
|
||||
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
p_as_grid_grp,
|
||||
p_bs_grid_grp,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(
|
||||
p_as_grid + a_batch_offset,
|
||||
p_bs_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_as_grid;
|
||||
ignore = p_bs_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
@@ -150,6 +190,9 @@ __global__ void
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
//
|
||||
// @brief Device Convolution operation.
|
||||
//
|
||||
@@ -211,8 +254,13 @@ template <index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeDataType = ADataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
typename ComputeDataType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>()), // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
: public DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
ALayout,
|
||||
@@ -230,6 +278,11 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle;
|
||||
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -325,51 +378,43 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
|
||||
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
NumGemmKPrefetchStage,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
|
||||
// it to it
|
||||
using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
|
||||
using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
|
||||
|
||||
#define GridwiseGemmTemplateParameters \
|
||||
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
|
||||
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
|
||||
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
|
||||
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
|
||||
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
|
||||
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
|
||||
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
|
||||
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
|
||||
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
|
||||
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
|
||||
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
|
||||
// Use appropriate gridwise gemm
|
||||
using GridwiseGemm =
|
||||
std::conditional_t<isMultiA || isMultiB,
|
||||
GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
|
||||
GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
|
||||
|
||||
// If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers =
|
||||
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
|
||||
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
|
||||
// in initializer list what is required for single const pointer).
|
||||
using AGridPointer = remove_cvref_t<
|
||||
decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>;
|
||||
using BGridPointer = remove_cvref_t<
|
||||
decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>;
|
||||
|
||||
// desc for blockwise copy
|
||||
using AGridDesc_AK0_M_AK1 =
|
||||
@@ -392,8 +437,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a,
|
||||
const void* p_b,
|
||||
Argument(APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
@@ -413,8 +458,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b)},
|
||||
: p_as_grid_{},
|
||||
p_bs_grid_{},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_c_wis_lengths[0]},
|
||||
@@ -458,9 +503,58 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
// A/B/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_batch_ for multiple AB
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
|
||||
|
||||
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
|
||||
// type is not tuple)
|
||||
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
|
||||
// It is possible that one of the AB is a pointer and one is a tuple.
|
||||
// Then also use multiAB but we have to cast single pointer instead of tuple of
|
||||
// pointer.
|
||||
if constexpr(isMultiA)
|
||||
{
|
||||
// p_as is tuple
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// if MultiB and not MultiA then p_as is single pointer
|
||||
p_as_grid_(i) = static_cast<const DataType*>(p_as);
|
||||
}
|
||||
});
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
// Init compute_ptr_offset_of_batch_ for multiple AB
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
|
||||
|
||||
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
|
||||
// It is possible that one of the AB is a pointer and one is a tuple.
|
||||
// Then also use multiAB but we have to cast single pointer instead of tuple of
|
||||
// pointer.
|
||||
if constexpr(isMultiB)
|
||||
{
|
||||
// p_bs is tuple
|
||||
p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
|
||||
}
|
||||
else
|
||||
{
|
||||
// if MultiA and not MultiB then p_bs is single pointer
|
||||
p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
|
||||
// p_as and p_bs are pointers
|
||||
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
|
||||
p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
|
||||
}
|
||||
|
||||
// populate pointer, batch stride, desc for Ds
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -477,21 +571,47 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
|
||||
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
|
||||
});
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
|
||||
|
||||
// populate desc for Ds/E
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
const auto as_grid_desc_ak0_m_ak1 =
|
||||
generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
|
||||
const auto bs_grid_desc_bk0_n_bk1 =
|
||||
generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
|
||||
b_grid_desc_n_k_,
|
||||
ds_grid_desc_m_n_,
|
||||
e_grid_desc_m_n_,
|
||||
block_2_etile_map_))
|
||||
{
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n_);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -505,9 +625,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
// pointers (tuple if multi AB, pointer if no)
|
||||
AGridPointer p_as_grid_;
|
||||
BGridPointer p_bs_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
@@ -529,7 +649,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
|
||||
compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -563,16 +684,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
|
||||
|
||||
@@ -582,41 +693,96 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor>,
|
||||
has_main_loop>;
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
// Generate tuples with grid descriptors for each A and B
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number<NumATensor>{});
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number<NumBTensor>{});
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_c_wis_lengths_[0], // Group count
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
AGridPointer,
|
||||
BGridPointer,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_,
|
||||
arg.p_bs_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_c_wis_lengths_[0], // Group count
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
const ADataType*,
|
||||
const BDataType*,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
|
||||
has_main_loop,
|
||||
isMultiA,
|
||||
isMultiB>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
|
||||
arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_c_wis_lengths_[0], // Group count
|
||||
arg.a_grid_desc_ak0_m_ak1_,
|
||||
arg.b_grid_desc_bk0_n_bk1_,
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_etile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
}
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
@@ -791,11 +957,27 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
// check Gridwise GEMM
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
if constexpr(isMultiA || isMultiB)
|
||||
{
|
||||
// Genarate tuples with the same descriptors
|
||||
const auto as_grid_desc_ak0_m_ak1 =
|
||||
generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number<NumATensor>{});
|
||||
const auto bs_grid_desc_bk0_n_bk1 =
|
||||
generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number<NumBTensor>{});
|
||||
return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_n_k_,
|
||||
arg.ds_grid_desc_m_n_,
|
||||
arg.e_grid_desc_m_n_,
|
||||
arg.block_2_etile_map_);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
@@ -804,8 +986,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
}
|
||||
|
||||
static auto MakeArgument(
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
APointers p_as,
|
||||
BPointers p_bs,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
@@ -824,8 +1006,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
return Argument{p_as,
|
||||
p_bs,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_c_wis_lengths,
|
||||
@@ -848,8 +1030,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
APointers p_a,
|
||||
BPointers p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
|
||||
@@ -9,8 +9,77 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDTensor>
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
NumBTensor,
|
||||
NumDTensor,
|
||||
ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
|
||||
Array<ck::index_t, NumBTensor>& BatchStrideBs,
|
||||
Array<ck::index_t, NumDTensor>& BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideAs),
|
||||
BatchStrideB_(BatchStrideBs),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumATensor> as_offset;
|
||||
static_for<0, NumATensor, 1>{}(
|
||||
[&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
|
||||
return as_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumBTensor> bs_offset;
|
||||
static_for<0, NumBTensor, 1>{}(
|
||||
[&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
|
||||
return bs_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
// alias for kernels without multiple D
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
Array<ck::index_t, NumATensor> BatchStrideA_;
|
||||
Array<ck::index_t, NumBTensor> BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
};
|
||||
|
||||
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
NumBTensor,
|
||||
NumDTensor,
|
||||
ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
@@ -54,13 +123,67 @@ struct ComputePtrOffsetOfStridedBatch
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
ck::index_t BatchStrideA_;
|
||||
ck::index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
};
|
||||
|
||||
template <bool isTuple, typename Tensors>
|
||||
constexpr static auto GetNumABTensors()
|
||||
{
|
||||
if constexpr(isTuple)
|
||||
{
|
||||
return Number<Tensors::Size()>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Number<1>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <bool isTuple, typename GridwiseGemm, typename DataType>
|
||||
constexpr static auto GetAGridPointer()
|
||||
{
|
||||
if constexpr(isTuple)
|
||||
{
|
||||
return typename GridwiseGemm::AsGridPointer{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Tuple<const DataType*>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <bool isTuple, typename GridwiseGemm, typename DataType>
|
||||
constexpr static auto GetBGridPointer()
|
||||
{
|
||||
if constexpr(isTuple)
|
||||
{
|
||||
return typename GridwiseGemm::BsGridPointer{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return Tuple<const DataType*>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <bool isTuple, typename Id, typename Type>
|
||||
constexpr static auto UnpackDataType()
|
||||
{
|
||||
if constexpr(isTuple)
|
||||
{
|
||||
// unpack if tuple
|
||||
return tuple_element_t<Id{}, Type>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
// if no, return Type
|
||||
return Type{};
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -142,19 +142,18 @@ struct DeviceImageToColumnImpl
|
||||
decltype(BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, KPerBlock, OutputGridDesc>(
|
||||
OutputGridDesc{}))>;
|
||||
|
||||
using GridwiseTensorRearrangeKernel =
|
||||
GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>>;
|
||||
using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange<InputGridDesc,
|
||||
InputDataType,
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
ThreadClusterLengths,
|
||||
ScalarPerVector,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<>>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -224,7 +223,7 @@ struct DeviceImageToColumnImpl
|
||||
InputGridDesc in_grid_desc_m_k_;
|
||||
OutputGridDesc out_grid_desc_m_k_;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
@@ -246,7 +245,7 @@ struct DeviceImageToColumnImpl
|
||||
OutputGridDesc,
|
||||
OutputDataType,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
ComputePtrOffsetOfStridedBatch<>,
|
||||
GridwiseTensorRearrangeKernel>;
|
||||
|
||||
float elapsed_time = launch_and_time_kernel(stream_config,
|
||||
|
||||
Reference in New Issue
Block a user