mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
update kernel
This commit is contained in:
@@ -74,28 +74,28 @@ template <typename GridwiseGemm,
|
||||
InMemoryDataOperationEnum OutElementOp>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOp a_element_op,
|
||||
const BElementwiseOp b_element_op,
|
||||
const CDEElementwiseOp cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
|
||||
const index_t KBatch)
|
||||
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
|
||||
const ABDataType* __restrict__ p_a_grid,
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOp a_element_op,
|
||||
const BElementwiseOp b_element_op,
|
||||
const CDEElementwiseOp cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
|
||||
const index_t KBatch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
|
||||
@@ -394,22 +394,30 @@ enum DepthwiseConv2dDirection
|
||||
DIRECTION_BACKWARD
|
||||
};
|
||||
|
||||
template <typename Argument,
|
||||
template <typename ABDataType,
|
||||
typename EDataType,
|
||||
DepthwiseConv2dDirection direction,
|
||||
index_t BlockSize,
|
||||
index_t BatchPerBlock,
|
||||
index_t GroupPerBlock,
|
||||
index_t TileOutW, // output tile width; this is the tile size in the gradientIn
|
||||
index_t TileOutH, // output tile height
|
||||
index_t up_w,
|
||||
index_t up_h,
|
||||
index_t down_w,
|
||||
index_t up_w,
|
||||
index_t down_h,
|
||||
index_t BlockSize>
|
||||
__global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
|
||||
index_t down_w,
|
||||
index_t pad_h,
|
||||
index_t pad_w>
|
||||
__global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __restrict__ p_gradOut,
|
||||
const ABDataType* __restrict__ p_weight,
|
||||
EDataType* __restrict__ p_gradIn,
|
||||
const index_t out_width,
|
||||
const index_t out_height,
|
||||
const index_t in_width,
|
||||
const index_t in_height,
|
||||
const index_t group_num,
|
||||
const index_t whole_batch_num)
|
||||
{
|
||||
using ABDataType = typename Argument::ABDataType;
|
||||
using EDataType = typename Argument::EDataType;
|
||||
|
||||
constexpr index_t ElementPerInFP4 = 16 / sizeof(ABDataType);
|
||||
constexpr index_t ElementPerOutFP4 = 16 / sizeof(EDataType);
|
||||
|
||||
@@ -418,14 +426,12 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
|
||||
static_assert(GroupPerBlock == 32, "Currently only support GroupPerBlock == 32");
|
||||
static_assert(BatchPerBlock % WaveNum == 0,
|
||||
"Currently BatchPerBlock should be dividable by WaveNum");
|
||||
constexpr index_t BatchIter = BatchPerBlock / WaveNum;
|
||||
|
||||
constexpr index_t TileInW = ((TileOutW - 1) * down_w + kernelW - 1) / up_w + 1;
|
||||
constexpr index_t TileInH = ((TileOutH - 1) * down_h + kernelH - 1) / up_h + 1;
|
||||
// WaveNum * GroupPerBlk * TileInH * TileInW
|
||||
constexpr index_t GroupPerBlockInFP4 = GroupPerBlock / ElementPerInFP4;
|
||||
|
||||
constexpr index_t WaveNum = BlockSize / warpSize;
|
||||
__shared__ volatile ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G
|
||||
__shared__ volatile ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock];
|
||||
// layout : B->H->W->G will use double buffer to go through BatchPerBlock
|
||||
@@ -513,11 +519,11 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
|
||||
rel_in_h * TileInW * GroupPerBlockInFP4 +
|
||||
rel_in_w * GroupPerBlockInFP4 + group_id * ElementPerInFP4;
|
||||
|
||||
bool is_in_bound = (in_x >= 0 && in_x < inWidth) && (in_y >= 0 && in_y < inHeight);
|
||||
float4_t v_0{0.f, 0.f, 0.f, 0.f};
|
||||
bool is_in_bound = (in_x >= 0 && in_x < in_width) && (in_y >= 0 && in_y < in_height);
|
||||
float4_t v{0.f, 0.f, 0.f, 0.f};
|
||||
if(is_in_bound)
|
||||
{
|
||||
v_0 = reinterpret_cast<const float4_t*>(
|
||||
v = reinterpret_cast<const float4_t*>(
|
||||
arg.p_a_grid_)[ingrad_offset / ElementPerInFP4];
|
||||
}
|
||||
|
||||
@@ -577,7 +583,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
|
||||
int outgrad_offset =
|
||||
(group_start_id_per_blk + group_out_id * ElementPerInFP4) *
|
||||
outgrad_group_stride +
|
||||
(batch_start_id_per_blk + batch_id_per_wave) * outgrad_batch_stride +
|
||||
(batch_start_id_per_blk + batch_id + batch_id_per_wave) * outgrad_batch_stride +
|
||||
out_y * outgrad_row_stride + out_x * outgrad_col_stride;
|
||||
|
||||
reinterpret_cast<ETypeDstVec_t*>(p_gradOut)[outgrad_offset / ElementPerFP4] =
|
||||
@@ -1386,17 +1392,31 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
if(stride_y == 1 && stride_x == 1 && pad_y == 1 && pad_x == 1)
|
||||
{
|
||||
return kernel_grouped_conv_bwd_data_optimized<ADataType,
|
||||
EDataType,
|
||||
GroupPerBlock,
|
||||
BatchPerBlock,
|
||||
BlockDim,
|
||||
3,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1>;
|
||||
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
|
||||
// EDataType,
|
||||
// GroupPerBlock,
|
||||
// BatchPerBlock,
|
||||
// BlockDim,
|
||||
// 3,
|
||||
// 3,
|
||||
// 1,
|
||||
// 1,
|
||||
// 1,
|
||||
// 1>;
|
||||
return kernel_grouped_conv_bwd_data_optimized_v2<ADataType,
|
||||
EDataType,
|
||||
DIRECTION_BACKWARD,
|
||||
BlockDim,
|
||||
BatchPerBlock,
|
||||
GroupPerBlock,
|
||||
4,
|
||||
4,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1>;
|
||||
}
|
||||
else if(stride_y == 2 && stride_x == 2 && pad_y == 1 && pad_x == 1)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user