update shader

This commit is contained in:
joye
2025-06-04 12:21:38 +08:00
parent 37555a8f66
commit 075527783c

View File

@@ -175,7 +175,11 @@ __global__ void
// grp3 grp4 grp5 grp6 grp7
// load weight: 64*5*5 half values need 200 requests when loaded as float4
template <typename ABDataType, typename EDataType, index_t GroupPerBlock, index_t BatchPerBlock>
template <typename ABDataType,
typename EDataType,
index_t GroupPerBlock,
index_t BatchPerBlock,
index_t BlockSize>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(512, CK_MIN_BLOCK_PER_CU)
@@ -207,7 +211,8 @@ __global__ void
const int pad_width = pad_x;
constexpr int batch_num = BatchPerBlock;
static_assert(batch_num == BlockSize / warpSize,
"currently only support one wave implement one batch");
// NHWGK layout; TODO: use the tensor's actual strides instead of the values computed below
const int outgrad_group_stride = 1;
const int outgrad_batch_stride = group_num * out_height * out_width;
@@ -223,8 +228,7 @@ __global__ void
constexpr index_t ElementPerFP4 = 16 / sizeof(ABDataType);
#pragma unroll
for(int index = tid; index < 64 * 5 * 5 / ElementPerFP4; index += blockDim.x)
for(int index = tid; index < 64 * 5 * 5 / ElementPerFP4; index += BlockSize)
{
reinterpret_cast<float4*>(shmem_weight)[index] =
reinterpret_cast<const float4*>(weight_ptr)[index];
@@ -232,100 +236,94 @@ __global__ void
block_sync_lds();
int batch_iter = blockDim.x / warpSize;
int local_grp_id = tid % warpSize;
int glb_batch_offset = (blockIdx.x % blockNumPerGroup) * BatchPerBlock;
int wave_id = __builtin_amdgcn_readfirstlane(tid / warpSize);
for(int batch_idx = 0; batch_idx < batch_num; batch_idx += batch_iter)
int batch_id_in_glb_mem = wave_id + glb_batch_offset;
const int base_outgrad_offset = (grp_idx + local_grp_id) * outgrad_group_stride +
batch_id_in_glb_mem * outgrad_batch_stride;
const int base_ingrad_offset =
(grp_idx + local_grp_id) * ingrad_group_stride + batch_id_in_glb_mem * ingrad_batch_stride;
const int base_filter_offset = local_grp_id * 25;
for(int h_idx = 0; h_idx < in_height; h_idx += 2)
{
int batch_id_in_glb_mem = wave_id + batch_idx + glb_batch_offset;
const int base_outgrad_offset = (grp_idx + local_grp_id) * outgrad_group_stride +
batch_id_in_glb_mem * outgrad_batch_stride;
const int base_ingrad_offset = (grp_idx + local_grp_id) * ingrad_group_stride +
batch_id_in_glb_mem * ingrad_batch_stride;
const int base_filter_offset = local_grp_id * 25;
for(int h_idx = 0; h_idx < in_height; h_idx += 2)
for(int w_idx = 0; w_idx < in_width; w_idx += 2)
{
for(int w_idx = 0; w_idx < in_width; w_idx += 2)
float sum[4]{0.f};
const int out_row_start_0 = __builtin_amdgcn_readfirstlane(
max(0, (h_idx - filter_height + pad_height + stride_y) / stride_y));
const int out_row_end_0 = __builtin_amdgcn_readfirstlane(
min(out_height - 1, (h_idx + pad_height) / stride_y));
const int out_row_start_1 = __builtin_amdgcn_readfirstlane(
max(0, (h_idx + 1 - filter_height + pad_height + stride_y) / stride_y));
const int out_row_end_1 = __builtin_amdgcn_readfirstlane(
min(out_height - 1, (h_idx + 1 + pad_height) / stride_y));
const int out_col_start_0 = __builtin_amdgcn_readfirstlane(
max(0, (w_idx - filter_width + pad_width + stride_x) / stride_x));
const int out_col_start_1 = __builtin_amdgcn_readfirstlane(
max(0, (w_idx + 1 - filter_width + pad_width + stride_x) / stride_x));
const int out_col_end_0 =
__builtin_amdgcn_readfirstlane(min(out_width - 1, (w_idx + pad_width) / stride_x));
const int out_col_end_1 = __builtin_amdgcn_readfirstlane(
min(out_width - 1, (w_idx + 1 + pad_width) / stride_x));
for(int out_row = out_row_start_0; out_row <= out_row_end_1; ++out_row)
{
float sum[4]{0.f};
const int out_row_start_0 = __builtin_amdgcn_readfirstlane(
max(0, (h_idx - filter_height + pad_height + stride_y) / stride_y));
const int out_row_end_0 = __builtin_amdgcn_readfirstlane(
min(out_height - 1, (h_idx + pad_height) / stride_y));
const int out_row_start_1 = __builtin_amdgcn_readfirstlane(
max(0, (h_idx + 1 - filter_height + pad_height + stride_y) / stride_y));
const int out_row_end_1 = __builtin_amdgcn_readfirstlane(
min(out_height - 1, (h_idx + 1 + pad_height) / stride_y));
const int out_col_start_0 = __builtin_amdgcn_readfirstlane(
max(0, (w_idx - filter_width + pad_width + stride_x) / stride_x));
const int out_col_start_1 = __builtin_amdgcn_readfirstlane(
max(0, (w_idx + 1 - filter_width + pad_width + stride_x) / stride_x));
const int out_col_end_0 = __builtin_amdgcn_readfirstlane(
min(out_width - 1, (w_idx + pad_width) / stride_x));
const int out_col_end_1 = __builtin_amdgcn_readfirstlane(
min(out_width - 1, (w_idx + 1 + pad_width) / stride_x));
for(int out_row = out_row_start_0; out_row <= out_row_end_1; ++out_row)
const int filter_row_0 =
__builtin_amdgcn_readfirstlane(h_idx + pad_height - out_row * stride_y);
const int filter_row_1 =
__builtin_amdgcn_readfirstlane(h_idx + 1 + pad_height - out_row * stride_y);
for(int out_col = out_col_start_0; out_col <= out_col_end_1; ++out_col)
{
const int filter_row_0 =
__builtin_amdgcn_readfirstlane(h_idx + pad_height - out_row * stride_y);
const int filter_row_1 =
__builtin_amdgcn_readfirstlane(h_idx + 1 + pad_height - out_row * stride_y);
for(int out_col = out_col_start_0; out_col <= out_col_end_1; ++out_col)
{
const int filter_col_0 =
__builtin_amdgcn_readfirstlane(w_idx + pad_width - out_col * stride_x);
const int filter_col_1 = __builtin_amdgcn_readfirstlane(
w_idx + 1 + pad_width - out_col * stride_x);
const int outgrad_offset = base_outgrad_offset +
out_row * outgrad_row_stride +
out_col * outgrad_col_stride;
const int filter_col_0 =
__builtin_amdgcn_readfirstlane(w_idx + pad_width - out_col * stride_x);
const int filter_col_1 =
__builtin_amdgcn_readfirstlane(w_idx + 1 + pad_width - out_col * stride_x);
const int outgrad_offset = base_outgrad_offset + out_row * outgrad_row_stride +
out_col * outgrad_col_stride;
ABDataType gradOut = p_gradOut[outgrad_offset];
ABDataType gradOut = p_gradOut[outgrad_offset];
bool row_in_axis0 = (out_row <= out_row_end_0);
bool row_in_axis1 = (out_row >= out_row_start_1);
bool col_in_axis0 = (out_col <= out_col_end_0);
bool col_in_axis1 = (out_col >= out_col_start_1);
bool row_in_axis0 = (out_row <= out_row_end_0);
bool row_in_axis1 = (out_row >= out_row_start_1);
bool col_in_axis0 = (out_col <= out_col_end_0);
bool col_in_axis1 = (out_col >= out_col_start_1);
sum[0] += ((row_in_axis0 && col_in_axis0)
? shmem_weight[base_filter_offset + filter_row_0 * 5 +
filter_col_0] *
gradOut
: 0.f);
sum[1] += ((row_in_axis0 && col_in_axis1)
? shmem_weight[base_filter_offset + filter_row_0 * 5 +
filter_col_1] *
gradOut
: 0.f);
sum[2] += ((row_in_axis1 && col_in_axis0)
? shmem_weight[base_filter_offset + filter_row_1 * 5 +
filter_col_0] *
gradOut
: 0.f);
sum[3] += ((row_in_axis1 && col_in_axis1)
? shmem_weight[base_filter_offset + filter_row_1 * 5 +
filter_col_1] *
gradOut
: 0.f);
}
sum[0] +=
((row_in_axis0 && col_in_axis0)
? shmem_weight[base_filter_offset + filter_row_0 * 5 + filter_col_0] *
gradOut
: 0.f);
sum[1] +=
((row_in_axis0 && col_in_axis1)
? shmem_weight[base_filter_offset + filter_row_0 * 5 + filter_col_1] *
gradOut
: 0.f);
sum[2] +=
((row_in_axis1 && col_in_axis0)
? shmem_weight[base_filter_offset + filter_row_1 * 5 + filter_col_0] *
gradOut
: 0.f);
sum[3] +=
((row_in_axis1 && col_in_axis1)
? shmem_weight[base_filter_offset + filter_row_1 * 5 + filter_col_1] *
gradOut
: 0.f);
}
}
#pragma unroll
for(int i = 0; i < 2; i++)
for(int i = 0; i < 2; i++)
{
#pragma unroll
for(int j = 0; j < 2; j++)
{
#pragma unroll
for(int j = 0; j < 2; j++)
{
const int output_offset = base_ingrad_offset +
(h_idx + i) * ingrad_row_stride +
(w_idx + j) * ingrad_col_stride;
p_gradIn[output_offset] = __float2half(sum[i * 2 + j]);
}
const int output_offset = base_ingrad_offset + (h_idx + i) * ingrad_row_stride +
(w_idx + j) * ingrad_col_stride;
p_gradIn[output_offset] = __float2half(sum[i * 2 + j]);
}
}
}
@@ -1117,15 +1115,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// constexpr bool has_main_loop = has_main_k_block_loop.value;
constexpr index_t GroupPerBlock = 64;
constexpr index_t BatchPerBlock = 8;
constexpr index_t BlockDim = 512;
const auto kernel = kernel_grouped_conv_bwd_data_optimized<ADataType,
EDataType,
GroupPerBlock,
BatchPerBlock>;
BatchPerBlock,
BlockDim>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(1344 / GroupPerBlock * 16, 1, 1),
dim3(512),
dim3(BlockDim),
0,
p_a_grid,
p_b_grid,