mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
tweak
This commit is contained in:
@@ -233,8 +233,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin < E; e_block_data_begin += EPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
@@ -246,8 +244,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
|
||||
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
@@ -304,8 +302,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
{
|
||||
threadwise_out_copy.Run(p_out_thread, p_out_global);
|
||||
|
||||
threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true);
|
||||
threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true);
|
||||
threadwise_out_copy.MoveSrcSlicingWindow(Sequence<0, 0, GemmNPerThreadSubC>{},
|
||||
True);
|
||||
threadwise_out_copy.MoveDstSlicingWindow(Sequence<0, 0, B1>{}, True);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -233,8 +233,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
@@ -263,15 +261,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
|
||||
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
|
||||
p_wei_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
@@ -288,14 +285,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true);
|
||||
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
@@ -369,8 +366,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{
|
||||
threadwise_out_copy.Run(p_out_thread, p_out_global);
|
||||
|
||||
threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true);
|
||||
threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true);
|
||||
threadwise_out_copy.MoveSrcSlicingWindow(Sequence<0, 0, GemmNPerThreadSubC>{},
|
||||
True);
|
||||
threadwise_out_copy.MoveDstSlicingWindow(Sequence<0, 0, B1>{}, True);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -447,17 +447,19 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
__device__ void
|
||||
MoveSrcSlicingWindow(T step_sizes,
|
||||
integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes,
|
||||
integral_constant<bool, PositiveDirection>{});
|
||||
mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
__device__ void
|
||||
MoveDstSlicingWindow(T step_sizes,
|
||||
integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
mThreadwiseLoad.MoveDstSlicingWindow(step_sizes,
|
||||
integral_constant<bool, PositiveDirection>{});
|
||||
mThreadwiseLoad.MoveDstSlicingWindow(step_sizes, positive_direction);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@@ -227,9 +227,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveDstSlicingWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
static_if<PositiveDirection>([&](auto) { mDstSliceOrigin += step_sizes; }).Else([&](auto) {
|
||||
mDstSliceOrigin -= step_sizes;
|
||||
});
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
mDstSliceOrigin += step_sizes;
|
||||
}).Else([&](auto) { mDstSliceOrigin -= step_sizes; });
|
||||
}
|
||||
|
||||
// private:
|
||||
|
||||
@@ -132,7 +132,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
#if 1
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
#else
|
||||
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
|
||||
Reference in New Issue
Block a user