Refactor direct convolution

This commit is contained in:
Chao Liu
2018-11-25 01:10:11 -06:00
parent 8732ea04fb
commit 24d2f034fa
14 changed files with 253 additions and 1291 deletions

View File

@@ -16,10 +16,10 @@ void device_direct_convolution_2(
wei_device_buf.ToDevice(wei.mData.data());
out_device_buf.ToDevice(out.mData.data());
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
@@ -36,11 +36,6 @@ void device_direct_convolution_2(
constexpr unsigned KPerThread = 4;
constexpr unsigned CPerThread = 2;
constexpr unsigned NBlockOpLen0 = 1;
constexpr unsigned NBlockOpLen1 = 1;
constexpr unsigned NBlockOpLen2 = 4;
constexpr unsigned NBlockOpLen3 = 32;
constexpr unsigned BlockSize = 128;
constexpr unsigned GridSize = (out_desc.GetLength(I0) / NPerBlock) *
@@ -73,10 +68,6 @@ void device_direct_convolution_2(
NPerThread,
KPerThread,
CPerThread,
NBlockOpLen0,
NBlockOpLen1,
NBlockOpLen2,
NBlockOpLen3,
BlockSize,
GridSize>
<<<grid_dim, block_dim>>>(InDesc{},