update kernel

This commit is contained in:
joyeamd
2025-06-09 19:37:44 +08:00
parent 0677989d23
commit b1d03f7e8a

View File

@@ -74,28 +74,28 @@ template <typename GridwiseGemm,
InMemoryDataOperationEnum OutElementOp>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOp a_element_op,
const BElementwiseOp b_element_op,
const CDEElementwiseOp cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
const index_t KBatch)
kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle(
const ABDataType* __restrict__ p_a_grid,
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOp a_element_op,
const BElementwiseOp b_element_op,
const CDEElementwiseOp cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN compute_ptr_offset_of_n,
const index_t KBatch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
#if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch);
@@ -394,22 +394,30 @@ enum DepthwiseConv2dDirection
DIRECTION_BACKWARD
};
template <typename Argument,
template <typename ABDataType,
typename EDataType,
DepthwiseConv2dDirection direction,
index_t BlockSize,
index_t BatchPerBlock,
index_t GroupPerBlock,
index_t TileOutW, // output tile width; this is the tile size in the gradientIn
index_t TileOutH, // output tile height
index_t up_w,
index_t up_h,
index_t down_w,
index_t up_w,
index_t down_h,
index_t BlockSize>
__global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
index_t down_w,
index_t pad_h,
index_t pad_w>
__global__ void kernel_grouped_conv_bwd_data_optimized_v2(const ABDataType* __restrict__ p_gradOut,
const ABDataType* __restrict__ p_weight,
EDataType* __restrict__ p_gradIn,
const index_t out_width,
const index_t out_height,
const index_t in_width,
const index_t in_height,
const index_t group_num,
const index_t whole_batch_num)
{
using ABDataType = typename Argument::ABDataType;
using EDataType = typename Argument::EDataType;
constexpr index_t ElementPerInFP4 = 16 / sizeof(ABDataType);
constexpr index_t ElementPerOutFP4 = 16 / sizeof(EDataType);
@@ -418,14 +426,12 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
static_assert(GroupPerBlock == 32, "Currently only support GroupPerBlock == 32");
static_assert(BatchPerBlock % WaveNum == 0,
"Currently BatchPerBlock should be dividable by WaveNum");
constexpr index_t BatchIter = BatchPerBlock / WaveNum;
constexpr index_t TileInW = ((TileOutW - 1) * down_w + kernelW - 1) / up_w + 1;
constexpr index_t TileInH = ((TileOutH - 1) * down_h + kernelH - 1) / up_h + 1;
// WaveNum * GroupPerBlk * TileInH * TileInW
constexpr index_t GroupPerBlockInFP4 = GroupPerBlock / ElementPerInFP4;
constexpr index_t WaveNum = BlockSize / warpSize;
__shared__ volatile ABDataType shmem_k[kernelH * kernelW * GroupPerBlock]; // layout : H->W->G
__shared__ volatile ABDataType shmem_x[WaveNum * TileInH * TileInW * GroupPerBlock];
// layout : B->H->W->G will use double buffer to go through BatchPerBlock
@@ -513,11 +519,11 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
rel_in_h * TileInW * GroupPerBlockInFP4 +
rel_in_w * GroupPerBlockInFP4 + group_id * ElementPerInFP4;
bool is_in_bound = (in_x >= 0 && in_x < inWidth) && (in_y >= 0 && in_y < inHeight);
float4_t v_0{0.f, 0.f, 0.f, 0.f};
bool is_in_bound = (in_x >= 0 && in_x < in_width) && (in_y >= 0 && in_y < in_height);
float4_t v{0.f, 0.f, 0.f, 0.f};
if(is_in_bound)
{
v_0 = reinterpret_cast<const float4_t*>(
v = reinterpret_cast<const float4_t*>(
arg.p_a_grid_)[ingrad_offset / ElementPerInFP4];
}
@@ -577,7 +583,7 @@ __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
int outgrad_offset =
(group_start_id_per_blk + group_out_id * ElementPerInFP4) *
outgrad_group_stride +
(batch_start_id_per_blk + batch_id_per_wave) * outgrad_batch_stride +
(batch_start_id_per_blk + batch_id + batch_id_per_wave) * outgrad_batch_stride +
out_y * outgrad_row_stride + out_x * outgrad_col_stride;
reinterpret_cast<ETypeDstVec_t*>(p_gradOut)[outgrad_offset / ElementPerFP4] =
@@ -1386,17 +1392,31 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
{
if(stride_y == 1 && stride_x == 1 && pad_y == 1 && pad_x == 1)
{
return kernel_grouped_conv_bwd_data_optimized<ADataType,
EDataType,
GroupPerBlock,
BatchPerBlock,
BlockDim,
3,
3,
1,
1,
1,
1>;
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
// EDataType,
// GroupPerBlock,
// BatchPerBlock,
// BlockDim,
// 3,
// 3,
// 1,
// 1,
// 1,
// 1>;
return kernel_grouped_conv_bwd_data_optimized_v2<ADataType,
EDataType,
DIRECTION_BACKWARD,
BlockDim,
BatchPerBlock,
GroupPerBlock,
4,
4,
1,
1,
1,
1,
1,
1>;
}
else if(stride_y == 2 && stride_x == 2 && pad_y == 1 && pad_x == 1)
{