mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
refactoring ConstantTensorDescriptor
[ROCm/composable_kernel commit: a0584426ff]
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"
|
||||
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
|
||||
@@ -57,27 +58,33 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
BlockSize,
|
||||
GridSize>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
float time = launch_kernel(
|
||||
#if 0
|
||||
gridwise_direct_convolution_2_nchw_kcyx_nkhw
|
||||
#else
|
||||
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#endif
|
||||
<T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
BlockSize,
|
||||
GridSize>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
|
||||
@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(
|
||||
#if 1
|
||||
#if 0
|
||||
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
#else
|
||||
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
Reference in New Issue
Block a user