update variable to not hard coding

This commit is contained in:
joye
2025-06-05 15:09:35 +08:00
parent 69b6a8b20c
commit 65df6b65ed

View File

@@ -197,12 +197,14 @@ __global__ void
const index_t out_height,
const index_t in_width,
const index_t in_height,
const index_t group_num)
const index_t group_num,
const index_t whole_batch_num,
const index_t filter_size)
{
constexpr int blockNumPerGroup = 128 / BatchPerBlock;
int grp_idx = GroupPerBlock * (blockIdx.x / blockNumPerGroup);
const ABDataType* weight_ptr = p_weight + grp_idx * 25;
int tid = threadIdx.x;
const int blockNumPerGroup = whole_batch_num / BatchPerBlock;
int grp_idx = GroupPerBlock * (blockIdx.x / blockNumPerGroup);
const ABDataType* weight_ptr = p_weight + grp_idx * filter_size;
int tid = threadIdx.x;
const int filter_height = filter_y;
const int filter_width = filter_x;
@@ -224,11 +226,11 @@ __global__ void
const int ingrad_row_stride = group_num * in_width;
const int ingrad_col_stride = group_num;
__shared__ ABDataType shmem_weight[64 * 5 * 5];
extern __shared__ ABDataType shmem_weight[];
constexpr index_t ElementPerFP4 = 16 / sizeof(ABDataType);
for(int index = tid; index < 64 * 5 * 5 / ElementPerFP4; index += BlockSize)
for(int index = tid; index < 64 * filter_size / ElementPerFP4; index += BlockSize)
{
reinterpret_cast<float4*>(shmem_weight)[index] =
reinterpret_cast<const float4*>(weight_ptr)[index];
@@ -247,7 +249,7 @@ __global__ void
batch_id_in_glb_mem * outgrad_batch_stride;
const int base_ingrad_offset =
(grp_idx + local_grp_id) * ingrad_group_stride + batch_id_in_glb_mem * ingrad_batch_stride;
const int base_filter_offset = local_grp_id * 25;
const int base_filter_offset = local_grp_id * filter_size;
for(int h_idx = 0; h_idx < in_height; h_idx += 2)
{
@@ -291,8 +293,10 @@ __global__ void
bool col_in_axis0 = (out_col <= out_col_end_0);
bool col_in_axis1 = (out_col >= out_col_start_1);
const int filter_offset0 = base_filter_offset + filter_row_0 * 5 + filter_col_0;
const int filter_offset1 = base_filter_offset + filter_row_1 * 5 + filter_col_0;
const int filter_offset0 =
base_filter_offset + filter_row_0 * filter_width + filter_col_0;
const int filter_offset1 =
base_filter_offset + filter_row_1 * filter_width + filter_col_0;
sum[0] +=
((row_in_axis0 && col_in_axis0) ? shmem_weight[filter_offset0] * gradOut
@@ -314,15 +318,23 @@ __global__ void
#pragma unroll
for(int j = 0; j < 2; j++)
{
const int output_offset = base_ingrad_offset + (h_idx + i) * ingrad_row_stride +
(w_idx + j) * ingrad_col_stride;
p_gradIn[output_offset] = __float2half(sum[i * 2 + j]);
if((h_idx + i < in_height) && (w_idx + j < in_width))
{
const int output_offset = base_ingrad_offset +
(h_idx + i) * ingrad_row_stride +
(w_idx + j) * ingrad_col_stride;
p_gradIn[output_offset] = __float2half(sum[i * 2 + j]);
}
}
}
}
}
}
// __global__ void kernel_grouped_conv_bwd_data_optimized_small()
// {
// }
} // namespace
// Conv backward data multiple D:
@@ -1114,26 +1126,32 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
GroupPerBlock,
BatchPerBlock,
BlockDim>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(1344 / GroupPerBlock * 16, 1, 1),
dim3(BlockDim),
0,
p_a_grid,
p_b_grid,
p_e_grid,
5,
5,
1,
1,
2,
2,
14,
14,
14,
14,
1344);
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.e_g_n_c_wis_lengths_[0] / GroupPerBlock * 128 / BatchPerBlock,
1,
1),
dim3(BlockDim),
GroupPerBlock * arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] *
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2] * sizeof(ADataType),
p_a_grid,
p_b_grid,
p_e_grid,
arg.b_g_k_c_xs_lengths_[NDimSpatial + 1],
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2],
arg.conv_filter_strides_[0],
arg.conv_filter_strides_[1],
arg.input_left_pads_[0],
arg.input_left_pads_[1],
arg.a_g_n_k_wos_lengths_[NDimSpatial + 1],
arg.a_g_n_k_wos_lengths_[NDimSpatial + 2],
arg.e_g_n_c_wis_lengths_[NDimSpatial + 1],
arg.e_g_n_c_wis_lengths_[NDimSpatial + 2],
arg.a_g_n_k_wos_lengths_[0],
arg.a_g_n_k_wos_lengths_[1],
arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] *
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2]);
// const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
// GridwiseGemm,
// ADataType, // TODO: distiguish A/B datatype