refactoring ConstantTensorDescriptor

[ROCm/composable_kernel commit: a0584426ff]
This commit is contained in:
Chao Liu
2019-03-17 03:22:41 -05:00
parent 0485704cb3
commit 6fd0910da8
9 changed files with 452 additions and 340 deletions

View File

@@ -2,6 +2,7 @@
#include <unistd.h>
#include "device.hpp"
#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
@@ -57,27 +58,33 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
for(unsigned i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
InDesc,
WeiDesc,
OutDesc,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread,
BlockSize,
GridSize>,
dim3(GridSize),
dim3(BlockSize),
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
float time = launch_kernel(
#if 0
gridwise_direct_convolution_2_nchw_kcyx_nkhw
#else
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#endif
<T,
InDesc,
WeiDesc,
OutDesc,
NPerBlock,
KPerBlock,
CPerBlock,
HoPerBlock,
WoPerBlock,
NPerThread,
KPerThread,
CPerThread,
HoPerThread,
WoPerThread,
BlockSize,
GridSize>,
dim3(GridSize),
dim3(BlockSize),
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
printf("Elapsed time : %f ms\n", time);
usleep(std::min(time * 1000, float(10000)));

View File

@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
for(unsigned i = 0; i < nrepeat; ++i)
{
float time = launch_kernel(
#if 1
#if 0
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
#else
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer