mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactoring block copy
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"
|
||||
//#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"
|
||||
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
@@ -47,6 +47,9 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
@@ -59,7 +62,7 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(
|
||||
#if 0
|
||||
#if 0
|
||||
gridwise_direct_convolution_2_nchw_kcyx_nkhw
|
||||
#else
|
||||
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
@@ -78,6 +81,8 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead,
|
||||
BlockSize,
|
||||
GridSize>,
|
||||
dim3(GridSize),
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
#include "tensor.hpp"
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
#include "conv_common.hip.hpp"
|
||||
#include "device_direct_convolution_1.hpp"
|
||||
//#include "device_direct_convolution_1.hpp"
|
||||
#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp"
|
||||
#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
|
||||
#include "device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp"
|
||||
//#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn.hpp"
|
||||
//#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
|
||||
//#include "device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp"
|
||||
|
||||
struct GeneratorTensor_1
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user