puting gridwise convolution into its own class

This commit is contained in:
Chao Liu
2019-04-02 19:37:02 -05:00
parent bdbc0eaad1
commit 0b41ca2d9e
3 changed files with 24 additions and 16 deletions

View File

@@ -1,8 +1,9 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hip.hpp"
//#include "gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
@@ -272,7 +273,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
{
constexpr auto gridwise_conv =
#if 1
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
#else
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
#endif
@@ -301,11 +302,12 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
WeiBlockCopyThreadPerDim0,
WeiBlockCopyThreadPerDim1,
InBlockCopyDataPerRead,
WeiBlockCopyDataPerRead>();
WeiBlockCopyDataPerRead>{};
float time = launch_kernel(gridwise_conv.Run,
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
dim3(GridSize),
dim3(BlockSize),
gridwise_conv,
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));