mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Grouped conv bwd weight with grouped gemm (#2304)
* Grouped conv bwd weight with grouped gemm * fixes * fix * Fixes * test comments * restore atol * fix
This commit is contained in:
@@ -25,6 +25,8 @@
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -51,6 +53,11 @@ namespace {
|
||||
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
|
||||
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*
|
||||
* MaxGroupedGemmGroupsNum specifies the number of gemm args at compile time. With this
|
||||
* implementation we can avoid copying data to a workspace before kernel launch, since the actual
|
||||
* number of groups is passed as a runtime parameter. If the number of groups is larger than
|
||||
* MaxGroupedGemmGroupsNum, then we run this kernel in a loop.
|
||||
*
|
||||
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
|
||||
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
|
||||
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
|
||||
@@ -60,17 +67,13 @@ template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
index_t MaxGroupedGemmGroupsNum,
|
||||
typename GemmArgs,
|
||||
typename AElementwiseOp,
|
||||
typename BElementwiseOp,
|
||||
typename CDEElementwiseOp,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum OutElementOp>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -81,25 +84,21 @@ __global__ void
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const std::array<GemmArgs, MaxGroupedGemmGroupsNum> gemm_kernel_args,
|
||||
const index_t gemms_count,
|
||||
const AElementwiseOp a_element_op,
|
||||
const BElementwiseOp b_element_op,
|
||||
const CDEElementwiseOp cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
|
||||
const index_t KBatch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
|
||||
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
@@ -119,43 +118,79 @@ __global__ void
|
||||
|
||||
DsPointer p_ds_grid_grp;
|
||||
|
||||
static constexpr index_t NumDTensor =
|
||||
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
|
||||
static constexpr index_t NumDTensor = DsPointer::Size();
|
||||
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
block_2_ctile_map,
|
||||
KBatch,
|
||||
k_idx);
|
||||
index_t left = 0;
|
||||
index_t right = gemms_count;
|
||||
index_t group_id = index_t((left + right) / 2);
|
||||
while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
|
||||
block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
|
||||
{
|
||||
right = group_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
|
||||
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
|
||||
{
|
||||
GridwiseGemm::template Run<true, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
GridwiseGemm::template Run<false, OutElementOp>(
|
||||
p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
|
||||
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
gemm_kernel_args[group_id].block_2_ctile_map_,
|
||||
KBatch,
|
||||
k_idx);
|
||||
}
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
ignore = gemm_kernel_args;
|
||||
ignore = gemms_count;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = compute_ptr_offset_of_n;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -239,6 +274,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
"wrong! only implemented for 2D and 3D now");
|
||||
|
||||
// MaxGroupedGemmGroupsNum specifies the number of gemm args at compile time. With this
|
||||
// implementation we can avoid copying data to a workspace before kernel launch, since the
|
||||
// actual number of groups is passed as a runtime parameter. If the number of groups is larger
|
||||
// than MaxGroupedGemmGroupsNum, then we run this kernel in a loop.
|
||||
static constexpr index_t MaxGroupedGemmGroupsNum = 32;
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -378,15 +419,58 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}));
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}));
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
|
||||
|
||||
// block-to-e-tile map
|
||||
using Block2ETileMap = remove_cvref_t<
|
||||
decltype(GridwiseGemmMultipleD_xdl_cshuffle<
|
||||
GridwiseGemmMultiDTemplateParams>::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
|
||||
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
|
||||
|
||||
// Per-GEMM argument bundle for the grouped-GEMM kernel launch.
// A fixed-size std::array of these (MaxGroupedGemmGroupsNum entries) is passed by value
// to the kernel; each entry holds the grid descriptors, the block-to-E-tile map, the
// range of blockIdx.x values served by this GEMM, and whether the GEMM has a main K loop.
struct GemmArgs
|
||||
{
|
||||
// Default construction is needed so that std::array<GemmArgs, N> can be created
// before the individual entries are assigned.
GemmArgs() = default;
|
||||
// Stores every argument in the member of the same name (with a trailing underscore).
GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
GroupedGemmBlock2ETileMap block_2_ctile_map,
|
||||
index_t BlockStart,
|
||||
index_t BlockEnd,
|
||||
bool HasMainKBlockLoop)
|
||||
: a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1),
|
||||
b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1),
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_(
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_(
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
|
||||
// block-to-e-tile map
|
||||
block_2_ctile_map_(block_2_ctile_map),
|
||||
BlockStart_(BlockStart),
|
||||
BlockEnd_(BlockEnd),
|
||||
HasMainKBlockLoop_(HasMainKBlockLoop)
|
||||
|
||||
{
|
||||
}
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
|
||||
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
// (the Argument constructor builds it as an OffsettedBlockToCTileMap offset by BlockStart)
|
||||
GroupedGemmBlock2ETileMap block_2_ctile_map_;
|
||||
// Half-open range [BlockStart_, BlockEnd_) of blockIdx.x values that belong to this GEMM;
// the kernel searches the argument array for the entry whose range contains its block id.
index_t BlockStart_, BlockEnd_;
|
||||
// Selects GridwiseGemm::Run<true, ...> or Run<false, ...> inside the kernel at runtime.
bool HasMainKBlockLoop_;
|
||||
};
|
||||
using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, MPerBlock>;
|
||||
using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
|
||||
|
||||
@@ -589,9 +673,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
index_t grid_size = 0;
|
||||
// Allocate place for sets of gemms
|
||||
gemm_kernel_args_.resize(
|
||||
math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum));
|
||||
|
||||
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
|
||||
{
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
@@ -694,36 +782,51 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
|
||||
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
|
||||
|
||||
// desc for blockwise copy
|
||||
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
|
||||
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
|
||||
const index_t grid_size_grp = Block2ETileMap::CalculateGridSize(
|
||||
e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1));
|
||||
|
||||
// block-to-e-tile-map
|
||||
auto block_2_etile_map =
|
||||
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
|
||||
const index_t BlockStart = grid_size;
|
||||
const index_t BlockEnd = grid_size + grid_size_grp;
|
||||
|
||||
block_2_etile_map_container_.push_back(block_2_etile_map);
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
|
||||
b_grid_desc_n_k,
|
||||
ds_grid_desc_m_n,
|
||||
e_grid_desc_m_n,
|
||||
block_2_etile_map,
|
||||
k_batch_))
|
||||
// block-to-e-tile map
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(Block2ETileMap(e_grid_desc_m_n.GetLength(I0),
|
||||
e_grid_desc_m_n.GetLength(I1)),
|
||||
BlockStart);
|
||||
|
||||
const auto GemmK = a_grid_desc_m_k.GetLength(I1);
|
||||
const bool HasMainKBlockLoop =
|
||||
GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_);
|
||||
|
||||
gemm_kernel_args_[gemms_count_ /
|
||||
MaxGroupedGemmGroupsNum][gemms_count_ %
|
||||
MaxGroupedGemmGroupsNum] =
|
||||
GemmArgs{a_grid_desc_ak0_m_ak1,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
GridwiseGemm::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n),
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n),
|
||||
block_2_etile_map,
|
||||
BlockStart,
|
||||
BlockEnd,
|
||||
HasMainKBlockLoop};
|
||||
gemms_count_++;
|
||||
if(gemms_count_ % MaxGroupedGemmGroupsNum == 0)
|
||||
{
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
|
||||
GridwiseGemm::
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n));
|
||||
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n));
|
||||
gemms_grid_size_.push_back(grid_size);
|
||||
grid_size = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
gemm_kernel_args_.resize(
|
||||
math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum));
|
||||
gemms_grid_size_.push_back(grid_size);
|
||||
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
|
||||
@@ -830,31 +933,28 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
void Print() const
|
||||
{
|
||||
for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
std::cout << "a_grid_desc_ak0_m_ak1_container_"
|
||||
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
|
||||
std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i]
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "b_grid_desc_bk0_n_bk1_container_"
|
||||
<< b_grid_desc_bk0_n_bk1_container_[i] << std::endl;
|
||||
std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i]
|
||||
<< std::endl;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
|
||||
<< ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j]
|
||||
<< std::endl;
|
||||
<< ds_grid_desc_m_n_container_[i][j] << std::endl;
|
||||
});
|
||||
|
||||
std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
|
||||
<< e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i]
|
||||
<< std::endl;
|
||||
<< e_grid_desc_m_n_container_[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>::DsGridPointer
|
||||
p_ds_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptor for problem definition
|
||||
@@ -865,16 +965,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
|
||||
std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
|
||||
|
||||
// tensor descriptor for block-wise copy
|
||||
std::vector<AGridDesc_AK0_M_AK1> a_grid_desc_ak0_m_ak1_container_;
|
||||
std::vector<BGridDesc_BK0_N_BK1> b_grid_desc_bk0_n_bk1_container_;
|
||||
std::vector<DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_;
|
||||
std::vector<EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_container_;
|
||||
|
||||
// block-to-e-tile map
|
||||
std::vector<Block2ETileMap> block_2_etile_map_container_;
|
||||
Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_,
|
||||
elementwise_block_2_ctile_map_transpose_e_;
|
||||
Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_;
|
||||
@@ -903,6 +994,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
const index_t k_batch_;
|
||||
index_t num_workgroups_per_Conv_N_;
|
||||
std::vector<index_t> gemms_grid_size_;
|
||||
index_t gemms_count_ = 0;
|
||||
std::vector<std::array<GemmArgs, MaxGroupedGemmGroupsNum>> gemm_kernel_args_;
|
||||
|
||||
bool bwd_needs_zero_out;
|
||||
long_index_t e_space_size_bytes;
|
||||
};
|
||||
@@ -941,84 +1036,61 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
|
||||
}
|
||||
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
|
||||
gemm_set_id++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
arg.b_grid_desc_n_k_container_[i],
|
||||
arg.ds_grid_desc_m_n_container_[i],
|
||||
arg.e_grid_desc_m_n_container_[i],
|
||||
arg.block_2_etile_map_container_[i],
|
||||
arg.k_batch_))
|
||||
{
|
||||
throw std::runtime_error("wrong! device_op has invalid setting");
|
||||
}
|
||||
|
||||
const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
|
||||
arg.e_grid_desc_m_n_container_[i]);
|
||||
|
||||
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
|
||||
const index_t gdx = arg.gemms_grid_size_[gemm_set_id];
|
||||
const index_t gemms_count_for_set =
|
||||
gemm_set_id == arg.gemm_kernel_args_.size() - 1
|
||||
? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id
|
||||
: MaxGroupedGemmGroupsNum;
|
||||
const std::array<GemmArgs, MaxGroupedGemmGroupsNum>& gemm_kernel_args =
|
||||
arg.gemm_kernel_args_[gemm_set_id];
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.bwd_needs_zero_out && i == 0)
|
||||
if(arg.bwd_needs_zero_out && gemm_set_id == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
auto launch_kernel = [&]() {
|
||||
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
MaxGroupedGemmGroupsNum,
|
||||
GemmArgs,
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
has_main_loop,
|
||||
ElementOp>;
|
||||
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_grid_desc_ak0_m_ak1_container_[i],
|
||||
arg.b_grid_desc_bk0_n_bk1_container_[i],
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.block_2_etile_map_container_[i],
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
return launch_and_time_kernel_with_preprocess(stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
arg.p_ds_grid_,
|
||||
p_e_grid,
|
||||
gemm_kernel_args,
|
||||
gemms_count_for_set,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_,
|
||||
arg.k_batch_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_))
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
ave_time += launch_kernel();
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
@@ -1304,14 +1376,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
arg.b_grid_desc_n_k_container_[i],
|
||||
arg.ds_grid_desc_m_n_container_[i],
|
||||
arg.e_grid_desc_m_n_container_[i],
|
||||
arg.block_2_etile_map_container_[i],
|
||||
arg.k_batch_))
|
||||
if(!GridwiseGemm::CheckValidity(
|
||||
arg.a_grid_desc_m_k_container_[i],
|
||||
arg.b_grid_desc_n_k_container_[i],
|
||||
arg.ds_grid_desc_m_n_container_[i],
|
||||
arg.e_grid_desc_m_n_container_[i],
|
||||
arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum][i % MaxGroupedGemmGroupsNum]
|
||||
.block_2_ctile_map_,
|
||||
arg.k_batch_))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -271,6 +271,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
|
||||
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
|
||||
index_t N,
|
||||
index_t M01 = 8)
|
||||
@@ -870,6 +871,7 @@ struct OffsettedBlockToCTileMap
|
||||
{
|
||||
using underlying_type = UnderlyingBlockToCTileMap;
|
||||
|
||||
__host__ __device__ OffsettedBlockToCTileMap() = default;
|
||||
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
|
||||
index_t block_start)
|
||||
{
|
||||
|
||||
@@ -186,8 +186,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
|
||||
pass = pass & ck::utils::check_err(
|
||||
in_device, in_host, "Error: Incorrect results!", rtol, atol);
|
||||
pass &= ck::utils::check_err(
|
||||
in_device, in_host, "Error: Incorrect results!", rtol, atol);
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
|
||||
|
||||
@@ -261,8 +261,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
|
||||
max_accumulated_value, num_accums_split_k);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
// Use default atol for splitK == 1
|
||||
bool pass = ck::utils::check_err(weight_device_result,
|
||||
weight_host_result,
|
||||
"Error: Incorrect results!",
|
||||
|
||||
@@ -96,6 +96,18 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
|
||||
// GroupedGemmGroupsNum = 4, ZTilde * YTilde * XTilde = 4, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
|
||||
// GroupedGemmGroupsNum = 9, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
|
||||
// GroupedGemmGroupsNum = 36, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
|
||||
// GroupedGemmGroupsNum = 32, ZTilde * YTilde * XTilde = 32, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
|
||||
Reference in New Issue
Block a user