mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
update kernel; not correctly
This commit is contained in:
@@ -369,12 +369,196 @@ __global__ void
|
||||
}
|
||||
}
|
||||
|
||||
// template <typename Argument>
|
||||
// __global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
|
||||
// {
|
||||
/*
|
||||
one wave implement one batch 32 group; one group implements tile_x * tile_y output; what about the
|
||||
size not divided by 4?
|
||||
|
||||
// }
|
||||
one wave implements 32 x 4 x 4; one thread can fetch 8 groups h0w0 data; 4 thread will 32 groups'
|
||||
h0w0; 64 thread can load 4x4;
|
||||
|
||||
tid 0 and 32 works on group 0; tid 1 and 33 works on group 1; tid 2 and 34 works on
|
||||
group 2; tid 3 and 35 works on group 3; etcs.
|
||||
|
||||
wave 1 in the same block goes through the batch direction
|
||||
gridDim(ceiling(InWidth / TileOutW), ceiling(InHeight / TileOutH), (WholeBatchNum / BatchPerBlock)
|
||||
* (WholeGroupNum / GroupPerBlk))
|
||||
|
||||
BlockDim(warpSize * warpNum, 1, 1)
|
||||
when foward, up means dilate, down means stride
|
||||
when backward, up means stride, down means dilate
|
||||
*/
|
||||
enum DepthwiseConv2dDirection
|
||||
{
|
||||
DIRECTION_FORWARD,
|
||||
DIRECTION_BACKWARD
|
||||
};
|
||||
|
||||
template <typename Argument,
|
||||
DepthwiseConv2dDirection direction,
|
||||
index_t BatchPerBlk,
|
||||
index_t GroupPerBlk,
|
||||
index_t TileOutW, // output tile width; this is the tile size in the gradientIn
|
||||
index_t TileOutH, // output tile height
|
||||
index_t up_w,
|
||||
index_t up_h,
|
||||
index_t down_w,
|
||||
index_t down_h,
|
||||
index_t BlockSize>
|
||||
__global__ void kernel_grouped_conv_bwd_data_optimized_v2(Argument& arg)
|
||||
{
|
||||
using ABDataType = typename Argument::ABDataType;
|
||||
using EDataType = typename Argument::EDataType;
|
||||
|
||||
constexpr index_t ElementPerFP4 = 16 / sizeof(ABDataType);
|
||||
|
||||
static_assert(GroupPerBlk == 32, "Currently only support GroupPerWave == 32");
|
||||
constexpr index_t TileInW = ((TileOutW - 1) * down_w + kernelW - 1) / up_w + 1;
|
||||
constexpr index_t TileInH = ((TileOutH - 1) * down_h + kernelH - 1) / up_h + 1;
|
||||
|
||||
constexpr index_t WaveNum = BlockSize / warpSize;
|
||||
__shared__ volatile ABDataType shmem_k[GroupPerBlk * kernelH * kernelW]; // layout : H->W->G
|
||||
__shared__ volatile ABDataType shmem_x[2][WaveNum * GroupPerBlk * TileInH * TileInW];
|
||||
// layout : B->H->W->G will use double buffer to go through BatchPerBlock
|
||||
// when backward data, will be gradOut, when forward, will be x
|
||||
|
||||
const int output_tile_w = blockIdx.x * TileOutW;
|
||||
const int output_tile_h = blockIdx.y * TileOutH;
|
||||
|
||||
if(output_tile_w >= outWidth || output_tile_h >= outHeight)
|
||||
{
|
||||
return; // out of bound
|
||||
}
|
||||
|
||||
const int GroupBatchNum = WholeGroupNum / GroupPerBlk;
|
||||
|
||||
// NHWGK todo use the stride
|
||||
const int outgrad_group_stride = 1;
|
||||
const int outgrad_batch_stride = group_num * out_height * out_width;
|
||||
const int outgrad_row_stride = group_num * out_width;
|
||||
const int outgrad_col_stride = group_num;
|
||||
// NHWGC
|
||||
const int ingrad_group_stride = 1;
|
||||
const int ingrad_batch_stride = group_num * in_height * in_width;
|
||||
const int ingrad_row_stride = group_num * in_width;
|
||||
const int ingrad_col_stride = group_num;
|
||||
|
||||
int tid = threadIdx.x;
|
||||
|
||||
// this offset is used to calculate the start offset of the group and batch
|
||||
const int group_start_id_per_blk = (blockIdx.z % GroupBatchNum) * GroupPerBlk;
|
||||
const int batch_start_id_per_blk = (blockIdx.z / GroupBatchNum) * BatchPerBlk;
|
||||
|
||||
int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize);
|
||||
|
||||
int tile_mid_w = tileOutW * down_w + up_w - 1 - pad_w;
|
||||
int tile_mid_h = tileOutH * down_h + up_h - 1 - pad_h;
|
||||
int tile_in_x = tile_mid_w / up_w;
|
||||
int tile_in_y = tile_mid_h / up_h;
|
||||
|
||||
// WaveNum * GroupPerBlk * TileInH * TileInW
|
||||
constexpr index_t GroupPerBlockInFP4 = GroupPerBlk / ElementPerFP4;
|
||||
|
||||
int group_id = tid % GroupPerBlockInFP4;
|
||||
int rel_in_w = (tid / GroupPerBlockInFP4) % TileInW;
|
||||
int rel_in_h = (tid / (GroupPerBlockInFP4 * TileInW)) % TileInH;
|
||||
|
||||
int in_x = rel_in_w + tile_in_x;
|
||||
int in_y = rel_in_h + tile_in_y;
|
||||
|
||||
int local_batch_id = wave_id;
|
||||
int ingrad_offset = (group_start_id_per_blk + group_id * ElementPerFP4) * ingrad_group_stride +
|
||||
(batch_start_id_per_blk + local_batch_id) * ingrad_batch_stride +
|
||||
in_y * ingrad_row_stride + in_x * ingrad_col_stride;
|
||||
|
||||
int shmem_offset = wave_id * GroupPerBlockInFP4 * TileInH * TileInW +
|
||||
rel_in_h * TileInW * GroupPerBlockInFP4 + rel_in_w * GroupPerBlockInFP4 +
|
||||
group_id * ElementPerFP4;
|
||||
|
||||
bool is_in_bound = (in_x >= 0 && in_x < inWidth) && (in_y >= 0 && in_y < inHeight);
|
||||
|
||||
// static_assert(
|
||||
// WaveNum * GroupPerBlockInFP4 * TileInH * TileInW % BlockSize == 0,
|
||||
// "WaveNum * GroupPerBlockInFP4 * TileInH * TileInW must be divisible by BlockSize");
|
||||
constexpr int InLoopNum = WaveNum * GroupPerBlockInFP4 * TileInH * TileInW;
|
||||
|
||||
#pragma unroll
|
||||
for(int i = tid; i < InLoopNum; i += BlockSize)
|
||||
{
|
||||
float4_t v_0{0.f, 0.f, 0.f, 0.f};
|
||||
if(is_in_bound)
|
||||
{
|
||||
v_0 = reinterpret_cast<const float4_t*>(p_in)[ingrad_offset / ElementPerFP4];
|
||||
}
|
||||
|
||||
reinterpret_cast<float4_t*>(shmem_x)[shmem_offset] = v;
|
||||
}
|
||||
|
||||
// load weight to shared memory
|
||||
// global weight layout : GKYXC; shared weight layout : Y->X->G
|
||||
for(int i = tid; i < GroupPerBlk * kernelH * kernelW; i += BlockSize)
|
||||
{
|
||||
int local_group_id = i / (kernelH * kernelW);
|
||||
int glb_group_id = local_group_id + group_start_id_per_blk;
|
||||
int kernel_h = (i % (kernelH * kernelW)) / kernelW;
|
||||
int kernel_w = i % kernelW;
|
||||
|
||||
shmem_k[kernel_h * kernelW * GroupPerBlk + kernel_w * GroupPerBlk + local_group_id] =
|
||||
p_weight[glb_group_id * kernelH * kernelW + kernel_h * kernelW + kernel_w];
|
||||
}
|
||||
|
||||
int ping = 0;
|
||||
for(int i = 1; i < BatchPerBlk; i++)
|
||||
{
|
||||
ingrad_offset += ingrad_batch_stride / ElementPerFP4;
|
||||
|
||||
block_sync_lds();
|
||||
#pragma unroll
|
||||
for(int i = tid; i < InLoopNum; i += BlockSize)
|
||||
{
|
||||
float4_t v{0.f, 0.f, 0.f, 0.f};
|
||||
if(is_in_bound)
|
||||
{
|
||||
v = reinterpret_cast<const float4*>(p_in)[ingrad_offset / ElementPerFP4];
|
||||
}
|
||||
reinterpret_cast<float4*>(shmem_x[1 - ping])[shmem_offset] = v;
|
||||
}
|
||||
constexpr int OutLoopNum = WaveNum * GroupPerBlockInFP4 * TileOutW * TileOutH;
|
||||
|
||||
for(int out_idx = tid; out_idx < OutLoopNum; out_idx += BlockSize)
|
||||
{
|
||||
int rel_out_y = out_idx / TileOutW;
|
||||
int rel_out_x = out_idx - rel_out_y * TileOutW;
|
||||
int out_y = rel_out_y + tile_out_y;
|
||||
int out_x = rel_out_x + tile_out_x;
|
||||
|
||||
int mid_x = tile_mid_x + rel_out_x * down_x;
|
||||
int mid_y = tile_mid_y + rel_out_y * down_y;
|
||||
int in_x = floor_div(mid_x, up_x);
|
||||
int in_y = floor_div(mid_y, up_y);
|
||||
int rel_in_x = in_x - tile_in_x;
|
||||
int rel_in_y = in_y - tile_in_y;
|
||||
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
||||
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
||||
if(in_x < 0 || in_x >= inWidth || in_y < 0 || in_y >= inHeight)
|
||||
continue;
|
||||
|
||||
#pragma unroll
|
||||
for(int y = 0; y < kernel_h / up_y; y++)
|
||||
#pragma unroll
|
||||
for(int x = 0; x < kernel_w / up_x; x++)
|
||||
{
|
||||
v += sx[rel_in_y + y][rel_in_x + x] *
|
||||
sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
||||
}
|
||||
|
||||
if(out_x < p.out_w & out_y < p.out_h)
|
||||
{
|
||||
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) + minor_idx] = v;
|
||||
}
|
||||
}
|
||||
|
||||
ping = 1 - ping;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Conv backward data multiple D:
|
||||
@@ -1151,7 +1335,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
throw std::runtime_error("wrong! device_op has invalid setting");
|
||||
}
|
||||
|
||||
// const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
|
||||
// const index_t gdx =
|
||||
// arg.block_2_etile_map_container_[i].CalculateGridSize(
|
||||
// arg.e_grid_desc_m_n_container_[i]);
|
||||
|
||||
// const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
|
||||
@@ -1192,7 +1377,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.a_g_n_k_wos_lengths_[1],
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] *
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2]);
|
||||
// const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
// const auto kernel =
|
||||
// kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
// GridwiseGemm,
|
||||
// ADataType, // TODO: distiguish A/B datatype
|
||||
// typename GridwiseGemm::DsGridPointer,
|
||||
@@ -1571,11 +1757,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout
|
||||
<< "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::"
|
||||
"Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user