Grouped conv bwd weight with grouped gemm (#2304)

* Grouped conv bwd weight with grouped gemm

* fixes

* fix

* Fixes

* test comments

* restore atol

* fix
This commit is contained in:
Bartłomiej Kocot
2025-06-12 10:15:07 +02:00
committed by GitHub
parent 8aff45a8af
commit bb4f471b09
5 changed files with 242 additions and 153 deletions

View File

@@ -25,6 +25,8 @@
#include "ck/host_utility/flush_cache.hpp"
#include "ck/host_utility/io.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
@@ -51,6 +53,11 @@ namespace {
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
 * MaxGroupedGemmGroupsNum specifies the number of gemm args at compile time. With this
 * implementation we can avoid copying data to the workspace before kernel launch, since the
 * number of groups is a runtime parameter. If the number of groups is larger than
 * MaxGroupedGemmGroupsNum, then we run this kernel in a loop.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
@@ -60,17 +67,13 @@ template <typename GridwiseGemm,
typename ABDataType,
typename DsPointer,
typename EDataType,
index_t MaxGroupedGemmGroupsNum,
typename GemmArgs,
typename AElementwiseOp,
typename BElementwiseOp,
typename CDEElementwiseOp,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum OutElementOp>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -81,25 +84,21 @@ __global__ void
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const std::array<GemmArgs, MaxGroupedGemmGroupsNum> gemm_kernel_args,
const index_t gemms_count,
const AElementwiseOp a_element_op,
const BElementwiseOp b_element_op,
const CDEElementwiseOp cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
const index_t KBatch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
@@ -119,43 +118,79 @@ __global__ void
DsPointer p_ds_grid_grp;
static constexpr index_t NumDTensor =
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
static constexpr index_t NumDTensor = DsPointer::Size();
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
GridwiseGemm::template Run<HasMainKBlockLoop, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock_,
block_2_ctile_map,
KBatch,
k_idx);
index_t left = 0;
index_t right = gemms_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ &&
block_args_id < gemm_kernel_args[group_id].BlockEnd_)) &&
left <= right)
{
if(block_args_id < gemm_kernel_args[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
if(gemm_kernel_args[group_id].HasMainKBlockLoop_)
{
GridwiseGemm::template Run<true, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
}
else
{
GridwiseGemm::template Run<false, OutElementOp>(
p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_kernel_args[group_id].block_2_ctile_map_,
KBatch,
k_idx);
}
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
ignore = gemm_kernel_args;
ignore = gemms_count;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = compute_ptr_offset_of_n;
ignore = block_2_ctile_map;
#endif
}
@@ -239,6 +274,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
"wrong! only implemented for 2D and 3D now");
// MaxGroupedGemmGroupsNum specifies the number of gemm args at compile time. With this
// implementation we can avoid copying data to the workspace before kernel launch, since the
// number of groups is a runtime parameter. If the number of groups is larger than
// MaxGroupedGemmGroupsNum, then we run this kernel in a loop.
static constexpr index_t MaxGroupedGemmGroupsNum = 32;
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1;
static constexpr index_t NumDTensor = DsDataType::Size();
@@ -378,15 +419,58 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}));
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}));
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}));
// block-to-e-tile map
using Block2ETileMap = remove_cvref_t<
decltype(GridwiseGemmMultipleD_xdl_cshuffle<
GridwiseGemmMultiDTemplateParams>::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
// Per-group GEMM arguments, stored in a fixed-size std::array so they can be
// passed to the kernel by value instead of being copied to a workspace buffer
// before launch.
struct GemmArgs
{
GemmArgs() = default;
// All descriptor/tile-map parameters are copied into the corresponding
// trailing-underscore members; BlockStart/BlockEnd delimit the workgroup-id
// range owned by this gemm group.
GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
GroupedGemmBlock2ETileMap block_2_ctile_map,
index_t BlockStart,
index_t BlockEnd,
bool HasMainKBlockLoop)
: a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1),
b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1),
ds_grid_desc_mblock_mperblock_nblock_nperblock_(
ds_grid_desc_mblock_mperblock_nblock_nperblock),
e_grid_desc_mblock_mperblock_nblock_nperblock_(
e_grid_desc_mblock_mperblock_nblock_nperblock),
// block-to-e-tile map
block_2_ctile_map_(block_2_ctile_map),
BlockStart_(BlockStart),
BlockEnd_(BlockEnd),
HasMainKBlockLoop_(HasMainKBlockLoop)
{
}
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
GroupedGemmBlock2ETileMap block_2_ctile_map_;
// Half-open workgroup-id range [BlockStart_, BlockEnd_): the kernel's binary
// search matches block ids with id >= BlockStart_ && id < BlockEnd_.
index_t BlockStart_, BlockEnd_;
// Selects the compile-time Run<true/false, ...> variant inside the kernel,
// replacing the former host-side has_main_k_block_loop dispatch.
bool HasMainKBlockLoop_;
};
using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, MPerBlock>;
using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
@@ -589,9 +673,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
index_t grid_size = 0;
// Allocate place for sets of gemms
gemm_kernel_args_.resize(
math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum));
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
{
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
{
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
@@ -694,36 +782,51 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
// desc for blockwise copy
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
const index_t grid_size_grp = Block2ETileMap::CalculateGridSize(
e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1));
// block-to-e-tile-map
auto block_2_etile_map =
GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
const index_t BlockStart = grid_size;
const index_t BlockEnd = grid_size + grid_size_grp;
block_2_etile_map_container_.push_back(block_2_etile_map);
grid_size += grid_size_grp;
if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
block_2_etile_map,
k_batch_))
// block-to-e-tile map
const auto block_2_etile_map =
GroupedGemmBlock2ETileMap(Block2ETileMap(e_grid_desc_m_n.GetLength(I0),
e_grid_desc_m_n.GetLength(I1)),
BlockStart);
const auto GemmK = a_grid_desc_m_k.GetLength(I1);
const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, k_batch_);
gemm_kernel_args_[gemms_count_ /
MaxGroupedGemmGroupsNum][gemms_count_ %
MaxGroupedGemmGroupsNum] =
GemmArgs{a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
GridwiseGemm::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n),
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n),
block_2_etile_map,
BlockStart,
BlockEnd,
HasMainKBlockLoop};
gemms_count_++;
if(gemms_count_ % MaxGroupedGemmGroupsNum == 0)
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
GridwiseGemm::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n));
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n));
gemms_grid_size_.push_back(grid_size);
grid_size = 0;
}
}
}
}
gemm_kernel_args_.resize(
math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum));
gemms_grid_size_.push_back(grid_size);
// A/B/Ds/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0];
@@ -830,31 +933,28 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
void Print() const
{
for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++)
{
std::cout << "a_grid_desc_ak0_m_ak1_container_"
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i]
<< std::endl;
std::cout << "b_grid_desc_bk0_n_bk1_container_"
<< b_grid_desc_bk0_n_bk1_container_[i] << std::endl;
std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i]
<< std::endl;
static_for<0, NumDTensor, 1>{}([&](auto j) {
std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<< ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j]
<< std::endl;
<< ds_grid_desc_m_n_container_[i][j] << std::endl;
});
std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
<< e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i]
<< std::endl;
<< e_grid_desc_m_n_container_[i] << std::endl;
}
}
// pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
typename GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmMultiDTemplateParams>::DsGridPointer
p_ds_grid_;
typename GridwiseGemm::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// tensor descriptor for problem definition
@@ -865,16 +965,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
// tensor descriptor for block-wise copy
std::vector<AGridDesc_AK0_M_AK1> a_grid_desc_ak0_m_ak1_container_;
std::vector<BGridDesc_BK0_N_BK1> b_grid_desc_bk0_n_bk1_container_;
std::vector<DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_;
std::vector<EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
e_grid_desc_mblock_mperblock_nblock_nperblock_container_;
// block-to-e-tile map
std::vector<Block2ETileMap> block_2_etile_map_container_;
Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_e_;
Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_;
@@ -903,6 +994,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const index_t k_batch_;
index_t num_workgroups_per_Conv_N_;
std::vector<index_t> gemms_grid_size_;
index_t gemms_count_ = 0;
std::vector<std::array<GemmArgs, MaxGroupedGemmGroupsNum>> gemm_kernel_args_;
bool bwd_needs_zero_out;
long_index_t e_space_size_bytes;
};
@@ -941,84 +1036,61 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType);
}
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size();
gemm_set_id++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i],
arg.block_2_etile_map_container_[i],
arg.k_batch_))
{
throw std::runtime_error("wrong! device_op has invalid setting");
}
const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
arg.e_grid_desc_m_n_container_[i]);
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
const index_t gdx = arg.gemms_grid_size_[gemm_set_id];
const index_t gemms_count_for_set =
gemm_set_id == arg.gemm_kernel_args_.size() - 1
? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id
: MaxGroupedGemmGroupsNum;
const std::array<GemmArgs, MaxGroupedGemmGroupsNum>& gemm_kernel_args =
arg.gemm_kernel_args_[gemm_set_id];
const auto clear_workspace = [&]() {
if(arg.bwd_needs_zero_out && i == 0)
if(arg.bwd_needs_zero_out && gemm_set_id == 0)
{
hip_check_error(hipMemsetAsync(
p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
}
};
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
auto launch_kernel = [&]() {
const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
typename GridwiseGemm::DsGridPointer,
EDataType,
MaxGroupedGemmGroupsNum,
GemmArgs,
AElementwiseOp,
BElementwiseOp,
CDEElementwiseOp,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
has_main_loop,
ElementOp>;
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_container_[i],
arg.b_grid_desc_bk0_n_bk1_container_[i],
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.block_2_etile_map_container_[i],
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
return launch_and_time_kernel_with_preprocess(stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
arg.p_ds_grid_,
p_e_grid,
gemm_kernel_args,
gemms_count_for_set,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_n_,
arg.k_batch_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_))
{
ave_time += launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time += launch_kernel(integral_constant<bool, false>{});
}
ave_time += launch_kernel();
}
return ave_time;
@@ -1304,14 +1376,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i],
arg.block_2_etile_map_container_[i],
arg.k_batch_))
if(!GridwiseGemm::CheckValidity(
arg.a_grid_desc_m_k_container_[i],
arg.b_grid_desc_n_k_container_[i],
arg.ds_grid_desc_m_n_container_[i],
arg.e_grid_desc_m_n_container_[i],
arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum][i % MaxGroupedGemmGroupsNum]
.block_2_ctile_map_,
arg.k_batch_))
{
return false;
}

View File

@@ -271,6 +271,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
index_t N,
index_t M01 = 8)
@@ -870,6 +871,7 @@ struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMap() = default;
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t block_start)
{

View File

@@ -186,8 +186,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
pass = pass & ck::utils::check_err(
in_device, in_host, "Error: Incorrect results!", rtol, atol);
pass &= ck::utils::check_err(
in_device, in_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;

View File

@@ -261,8 +261,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
max_accumulated_value, num_accums_split_k);
// Use higher threshold
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
// Use default atol for splitK == 1
bool pass = ck::utils::check_err(weight_device_result,
weight_host_result,
"Error: Incorrect results!",

View File

@@ -96,6 +96,18 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
{
this->conv_params.clear();
// GroupedGemmGroupsNum = 4, ZTilde * YTilde * XTilde = 4, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
// GroupedGemmGroupsNum = 9, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
// GroupedGemmGroupsNum = 36, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1}});
// GroupedGemmGroupsNum = 32, ZTilde * YTilde * XTilde = 32, MaxGroupedGemmGroupsNum = 32
this->conv_params.push_back(
{2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(