mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Update variables to avoid hard-coded values
This commit is contained in:
@@ -197,12 +197,14 @@ __global__ void
|
||||
const index_t out_height,
|
||||
const index_t in_width,
|
||||
const index_t in_height,
|
||||
const index_t group_num)
|
||||
const index_t group_num,
|
||||
const index_t whole_batch_num,
|
||||
const index_t filter_size)
|
||||
{
|
||||
constexpr int blockNumPerGroup = 128 / BatchPerBlock;
|
||||
int grp_idx = GroupPerBlock * (blockIdx.x / blockNumPerGroup);
|
||||
const ABDataType* weight_ptr = p_weight + grp_idx * 25;
|
||||
int tid = threadIdx.x;
|
||||
const int blockNumPerGroup = whole_batch_num / BatchPerBlock;
|
||||
int grp_idx = GroupPerBlock * (blockIdx.x / blockNumPerGroup);
|
||||
const ABDataType* weight_ptr = p_weight + grp_idx * filter_size;
|
||||
int tid = threadIdx.x;
|
||||
|
||||
const int filter_height = filter_y;
|
||||
const int filter_width = filter_x;
|
||||
@@ -224,11 +226,11 @@ __global__ void
|
||||
const int ingrad_row_stride = group_num * in_width;
|
||||
const int ingrad_col_stride = group_num;
|
||||
|
||||
__shared__ ABDataType shmem_weight[64 * 5 * 5];
|
||||
extern __shared__ ABDataType shmem_weight[];
|
||||
|
||||
constexpr index_t ElementPerFP4 = 16 / sizeof(ABDataType);
|
||||
|
||||
for(int index = tid; index < 64 * 5 * 5 / ElementPerFP4; index += BlockSize)
|
||||
for(int index = tid; index < 64 * filter_size / ElementPerFP4; index += BlockSize)
|
||||
{
|
||||
reinterpret_cast<float4*>(shmem_weight)[index] =
|
||||
reinterpret_cast<const float4*>(weight_ptr)[index];
|
||||
@@ -247,7 +249,7 @@ __global__ void
|
||||
batch_id_in_glb_mem * outgrad_batch_stride;
|
||||
const int base_ingrad_offset =
|
||||
(grp_idx + local_grp_id) * ingrad_group_stride + batch_id_in_glb_mem * ingrad_batch_stride;
|
||||
const int base_filter_offset = local_grp_id * 25;
|
||||
const int base_filter_offset = local_grp_id * filter_size;
|
||||
|
||||
for(int h_idx = 0; h_idx < in_height; h_idx += 2)
|
||||
{
|
||||
@@ -291,8 +293,10 @@ __global__ void
|
||||
bool col_in_axis0 = (out_col <= out_col_end_0);
|
||||
bool col_in_axis1 = (out_col >= out_col_start_1);
|
||||
|
||||
const int filter_offset0 = base_filter_offset + filter_row_0 * 5 + filter_col_0;
|
||||
const int filter_offset1 = base_filter_offset + filter_row_1 * 5 + filter_col_0;
|
||||
const int filter_offset0 =
|
||||
base_filter_offset + filter_row_0 * filter_width + filter_col_0;
|
||||
const int filter_offset1 =
|
||||
base_filter_offset + filter_row_1 * filter_width + filter_col_0;
|
||||
|
||||
sum[0] +=
|
||||
((row_in_axis0 && col_in_axis0) ? shmem_weight[filter_offset0] * gradOut
|
||||
@@ -314,15 +318,23 @@ __global__ void
|
||||
#pragma unroll
|
||||
for(int j = 0; j < 2; j++)
|
||||
{
|
||||
const int output_offset = base_ingrad_offset + (h_idx + i) * ingrad_row_stride +
|
||||
(w_idx + j) * ingrad_col_stride;
|
||||
p_gradIn[output_offset] = __float2half(sum[i * 2 + j]);
|
||||
if((h_idx + i < in_height) && (w_idx + j < in_width))
|
||||
{
|
||||
const int output_offset = base_ingrad_offset +
|
||||
(h_idx + i) * ingrad_row_stride +
|
||||
(w_idx + j) * ingrad_col_stride;
|
||||
p_gradIn[output_offset] = __float2half(sum[i * 2 + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// __global__ void kernel_grouped_conv_bwd_data_optimized_small()
|
||||
// {
|
||||
|
||||
// }
|
||||
} // namespace
|
||||
|
||||
// Conv backward data multiple D:
|
||||
@@ -1114,26 +1126,32 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
GroupPerBlock,
|
||||
BatchPerBlock,
|
||||
BlockDim>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(1344 / GroupPerBlock * 16, 1, 1),
|
||||
dim3(BlockDim),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_e_grid,
|
||||
5,
|
||||
5,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
14,
|
||||
14,
|
||||
14,
|
||||
14,
|
||||
1344);
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.e_g_n_c_wis_lengths_[0] / GroupPerBlock * 128 / BatchPerBlock,
|
||||
1,
|
||||
1),
|
||||
dim3(BlockDim),
|
||||
GroupPerBlock * arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] *
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2] * sizeof(ADataType),
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_e_grid,
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 1],
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2],
|
||||
arg.conv_filter_strides_[0],
|
||||
arg.conv_filter_strides_[1],
|
||||
arg.input_left_pads_[0],
|
||||
arg.input_left_pads_[1],
|
||||
arg.a_g_n_k_wos_lengths_[NDimSpatial + 1],
|
||||
arg.a_g_n_k_wos_lengths_[NDimSpatial + 2],
|
||||
arg.e_g_n_c_wis_lengths_[NDimSpatial + 1],
|
||||
arg.e_g_n_c_wis_lengths_[NDimSpatial + 2],
|
||||
arg.a_g_n_k_wos_lengths_[0],
|
||||
arg.a_g_n_k_wos_lengths_[1],
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 1] *
|
||||
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2]);
|
||||
// const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
|
||||
// GridwiseGemm,
|
||||
// ADataType, // TODO: distinguish A/B datatype
|
||||
|
||||
Reference in New Issue
Block a user