From 84d9802d30de16795e63a8625098634527c80ae4 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 15 Jan 2019 00:11:30 -0600 Subject: [PATCH] adding implicit gemm --- driver/conv.cu | 2 +- driver/device_implicit_gemm_convolution.cuh | 8 +- src/include/ConstantMatrixDescriptor.cuh | 46 +++++ ...iptor.cuh => ConstantTensorDescriptor.cuh} | 64 ------- src/include/blockwise_direct_convolution.cuh | 2 +- src/include/blockwise_tensor_op.cuh | 2 +- src/include/common.cuh | 70 +++++++- src/include/conv_common.cuh | 2 +- src/include/gemm.cuh | 67 ++++--- src/include/gridwise_direct_convolution_1.cuh | 2 +- src/include/gridwise_direct_convolution_2.cuh | 2 +- .../gridwise_implicit_gemm_convolution.cuh | 164 +++++++++++------- src/include/gridwise_winograd_convolution.cuh | 2 +- src/include/threadwise_direct_convolution.cuh | 2 +- src/include/threadwise_tensor_op.cuh | 2 +- 15 files changed, 268 insertions(+), 169 deletions(-) create mode 100644 src/include/ConstantMatrixDescriptor.cuh rename src/include/{constant_tensor_descriptor.cuh => ConstantTensorDescriptor.cuh} (68%) diff --git a/driver/conv.cu b/driver/conv.cu index 19a89a1e66..ee93414eac 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -4,7 +4,7 @@ #include #include "nvToolsExt.h" #include "tensor.hpp" -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" #include "conv_common.cuh" #include "device_direct_convolution_1.cuh" #include "device_direct_convolution_2.cuh" diff --git a/driver/device_implicit_gemm_convolution.cuh b/driver/device_implicit_gemm_convolution.cuh index 384b4c934f..2a529e98c2 100644 --- a/driver/device_implicit_gemm_convolution.cuh +++ b/driver/device_implicit_gemm_convolution.cuh @@ -27,14 +27,14 @@ void device_implicit_gemm_convolution( #if 1 constexpr unsigned NPerBlock = 2; - constexpr unsigned KPerBlock = 128; + constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 4; constexpr unsigned HoPerBlock = 2; constexpr unsigned WoPerBlock = 32; - constexpr unsigned NPerThread = 2; - constexpr unsigned KPerThread = 8; - constexpr unsigned CPerThread = 2; + constexpr unsigned NPerThread = 2; + constexpr unsigned KPerThread = 8; + constexpr unsigned CPerThread = 2; constexpr unsigned HoPerThread = 1; constexpr unsigned WoPerThread = 4; diff --git a/src/include/ConstantMatrixDescriptor.cuh b/src/include/ConstantMatrixDescriptor.cuh new file mode 100644 index 0000000000..352e0db0d8 --- /dev/null +++ b/src/include/ConstantMatrixDescriptor.cuh @@ -0,0 +1,46 @@ +#pragma once +#include "common.cuh" + +template +struct ConstantMatrixDescriptor +{ + __host__ __device__ ConstantMatrixDescriptor() + { + static_assert(NCol <= RowStride, "wrong! NCol > RowStride!"); + } + + __host__ __device__ constexpr unsigned GetNumberOfRow() const { return NRow; } + + __host__ __device__ constexpr unsigned GetNumberOfColumn() const { return NCol; } + + __host__ __device__ constexpr unsigned GetRowStride() const { return RowStride; } + + __host__ __device__ constexpr unsigned GetElementSize() const { return NRow * NCol; } + + __host__ __device__ constexpr unsigned GetElementSpace() const { return NRow * RowStride; } + + __host__ __device__ unsigned Get1dIndex(unsigned irow, unsigned icol) const + { + return irow * RowStride + icol; + } + + template + __host__ __device__ constexpr auto MakeSubMatrixDescriptor(Number, + Number) const + { + return ConstantMatrixDescriptor{}; + } +}; + +template +__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number, Number) +{ + return ConstantMatrixDescriptor{}; +} + +template +__host__ __device__ constexpr auto + make_ConstantMatrixDescriptor(Number, Number, Number) +{ + return ConstantMatrixDescriptor{}; +} diff --git a/src/include/constant_tensor_descriptor.cuh b/src/include/ConstantTensorDescriptor.cuh similarity index 68% rename from src/include/constant_tensor_descriptor.cuh rename to src/include/ConstantTensorDescriptor.cuh index a04ba63dd3..6030d51de6 100644 --- a/src/include/constant_tensor_descriptor.cuh +++ b/src/include/ConstantTensorDescriptor.cuh @@ -1,70 +1,6 @@ #pragma once #include "common.cuh" -template -struct Constant -{ - static const T mValue = N; -}; - -template -using Number = Constant; - -template -struct Sequence -{ - static constexpr unsigned nDim = sizeof...(Is); - - const unsigned mData[nDim] = {Is...}; - - template - __host__ __device__ constexpr unsigned Get(Number) const - { - return mData[I]; - } - - template - __host__ __device__ constexpr auto Reorder(Number, Number) const - { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); - - return Sequence{}; - } - - template - __host__ __device__ constexpr auto Reorder(Number, Number, Number) const - { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); - constexpr unsigned IR2 = Get(Number{}); - - return Sequence{}; - } - - template - __host__ __device__ constexpr auto Reorder(Number, Number, Number, Number) const - { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); - constexpr unsigned IR2 = Get(Number{}); - constexpr unsigned IR3 = Get(Number{}); - - return Sequence{}; - } - - template - __host__ __device__ constexpr auto Reorder(Sequence) const - { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); - constexpr unsigned IR2 = Get(Number{}); - constexpr unsigned IR3 = Get(Number{}); - - return Sequence{}; - } -}; - template struct ConstantTensorDescriptor { diff --git a/src/include/blockwise_direct_convolution.cuh b/src/include/blockwise_direct_convolution.cuh index 85bf1bddae..1cf261dc00 100644 --- a/src/include/blockwise_direct_convolution.cuh +++ b/src/include/blockwise_direct_convolution.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" #include "threadwise_tensor_op.cuh" #include "threadwise_direct_convolution.cuh" diff --git a/src/include/blockwise_tensor_op.cuh b/src/include/blockwise_tensor_op.cuh index 0649f6ba77..3635235770 100644 --- a/src/include/blockwise_tensor_op.cuh +++ b/src/include/blockwise_tensor_op.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" template __device__ void diff --git a/src/include/common.cuh b/src/include/common.cuh index dc8e2c5fed..0939c13227 100644 --- a/src/include/common.cuh +++ b/src/include/common.cuh @@ -12,4 +12,72 @@ struct is_same static const bool value = true; }; -__device__ unsigned get_thread_local_id() { return threadIdx.x; } +__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } + +__device__ unsigned get_block_1d_id() { return blockIdx.x; } + +template +struct Constant +{ + static const T mValue = N; + + __host__ __device__ constexpr T Get() const { return mValue; } +}; + +template +using Number = Constant; + +template +struct Sequence +{ + static constexpr unsigned nDim = sizeof...(Is); + + const unsigned mData[nDim] = {Is...}; + + template + __host__ __device__ constexpr unsigned Get(Number) const + { + return mData[I]; + } + + template + __host__ __device__ constexpr auto Reorder(Number, Number) const + { + constexpr unsigned IR0 = Get(Number{}); + constexpr unsigned IR1 = Get(Number{}); + + return Sequence{}; + } + + template + __host__ __device__ constexpr auto Reorder(Number, Number, Number) const + { + constexpr unsigned IR0 = Get(Number{}); + constexpr unsigned IR1 = Get(Number{}); + constexpr unsigned IR2 = Get(Number{}); + + return Sequence{}; + } + + template + __host__ __device__ constexpr auto Reorder(Number, Number, Number, Number) const + { + constexpr unsigned IR0 = Get(Number{}); + constexpr unsigned IR1 = Get(Number{}); + constexpr unsigned IR2 = Get(Number{}); + constexpr unsigned IR3 = Get(Number{}); + + return Sequence{}; + } + + template + __host__ __device__ constexpr auto Reorder(Sequence) const + { + constexpr unsigned IR0 = Get(Number{}); + constexpr unsigned IR1 = Get(Number{}); + constexpr unsigned IR2 = Get(Number{}); + constexpr unsigned IR3 = Get(Number{}); + + return Sequence{}; + } +}; diff --git a/src/include/conv_common.cuh b/src/include/conv_common.cuh index 81f0b167af..f1e2b2c9f9 100644 --- a/src/include/conv_common.cuh +++ b/src/include/conv_common.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" // this is ugly, only for 4d template diff --git a/src/include/gemm.cuh b/src/include/gemm.cuh index d7a8800431..6e08b0b1b1 100644 --- a/src/include/gemm.cuh +++ b/src/include/gemm.cuh @@ -1,12 +1,13 @@ #pragma once template __device__ void threadwise_gemm(ThreadMatrixA, @@ -26,41 +27,51 @@ __device__ void threadwise_gemm(ThreadMatrixA, template struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c { + unsigned mMyThreadOffsetA = 0; + unsigned mMyThreadOffsetB = 0; + struct MatrixIndex { unsigned batch_begin; - unsigned block_row_begin; - unsigned block_col_begin; + unsigned row_begin; + unsigned col_begin; }; __device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c() { static_assert(ThreadMatrixStrideC > 0, "wrong! ThreadMatrixStrideC == 0!"); - constexpr auto a_block = BlockMatrixA{}; - constexpr auto b_block = BlockMatrixB{}; +#if 0 + constexpr auto a_block_desc = BlockMatrixA{}; + constexpr auto b_block_desc = BlockMatrixB{}; - constexpr auto a_thread = ThreadMatrixA{}; - constexpr auto b_thread = ThreadMatrixB{}; - constexpr auto c_thread = ThreadMatrixC{}; + constexpr unsigned a_thread_row = (!TransA) ? MPerThread : KPerThread; + constexpr unsigned a_thread_col = (!TransA) ? KPerThread : MPerThread; + constexpr unsigned b_thread_row = (!TransB) ? KPerThread : NPerThread; + constexpr unsigned b_thread_col = (!TransB) ? NPerThread : KPerThread; - constexpr unsigned m_block = (!TransA) ? a_block.NRow() : a_block.NCol(); - constexpr unsigned n_block = (!TransB) ? b_block.NCol() : b_block.NRow(); + constexpr auto a_thread_desc = ConstantMatrixDescriptor{}; + constexpr auto b_thread_desc = ConstantMatrixDescriptor{}; + constexpr auto c_thread_desc = ConstantMatrixDescriptor{}; - constexpr unsigned m_thread = (!TransA) ? a_thread.NRow() : a_thread.NCol(); - constexpr unsigned n_thread = (!TransB) ? b_thread.NCol() : b_thread.NRow(); + constexpr unsigned m_block = (!TransA) ? a_block_desc.NRow() : a_block_desc.NCol(); + constexpr unsigned n_block = (!TransB) ? b_block_desc.NCol() : b_block_desc.NRow(); + + constexpr unsigned m_thread = (!TransA) ? a_thread_desc.NRow() : a_thread_desc.NCol(); + constexpr unsigned n_thread = (!TransB) ? b_thread_desc.NCol() : b_thread_desc.NRow(); constexpr unsigned num_threads_per_row = (m_block + m_thread - 1) / m_thread; constexpr unsigned num_threads_per_col = (n_block + n_thread - 1) / n_thread; @@ -72,12 +83,17 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c const auto mtx_c_idnex = CalculateThreadMatrixCIndex(get_thread_local_id()); - mMyThreadOffsetA = xxx; - mMyThreadoffSetB = xxx; + // mMyThreadOffsetA = xxx; + // mMyThreadoffSetB = xxx; +#else + mMyThreadOffsetA = 0; + mMyThreadOffsetB = 0; +#endif } __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const { +#if 0 constexpr auto a_block = BlockMatrixA{}; constexpr auto b_block = BlockMatrixB{}; constexpr auto c_block = BlockMatrixC{}; @@ -104,6 +120,9 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c return MatrixIndex{ batch_begin, thread_matrix_row_id * m_thread, thread_matrix_col_id * n_thread}; +#else + return MatrixIndex{0, 0, 0}; +#endif } template @@ -111,8 +130,4 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c { // do something } - - private: - unsigned mMyThreadOffsetA = 0; - unsigned mMyThreadOffsetB = 0; -} +}; diff --git a/src/include/gridwise_direct_convolution_1.cuh b/src/include/gridwise_direct_convolution_1.cuh index 4dfc6dfebc..91f8e59984 100644 --- a/src/include/gridwise_direct_convolution_1.cuh +++ b/src/include/gridwise_direct_convolution_1.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" #include "blockwise_tensor_op.cuh" #include "blockwise_direct_convolution.cuh" diff --git a/src/include/gridwise_direct_convolution_2.cuh b/src/include/gridwise_direct_convolution_2.cuh index e0ef90c0aa..f61139f116 100644 --- a/src/include/gridwise_direct_convolution_2.cuh +++ b/src/include/gridwise_direct_convolution_2.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" #include "blockwise_tensor_op.cuh" #include "blockwise_direct_convolution.cuh" #include "threadwise_tensor_op.cuh" diff --git a/src/include/gridwise_implicit_gemm_convolution.cuh b/src/include/gridwise_implicit_gemm_convolution.cuh index c870b8db55..6b2bb0fdd7 100644 --- a/src/include/gridwise_implicit_gemm_convolution.cuh +++ b/src/include/gridwise_implicit_gemm_convolution.cuh @@ -1,7 +1,10 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "common.cuh" +#include "ConstantTensorDescriptor.cuh" +#include "ConstantMatrixDescriptor.cuh" #include "blockwise_tensor_op.cuh" #include "threadwise_tensor_op.cuh" +#include "gemm.cuh" template {}); + // divide block work: NCHW + constexpr unsigned NBlockWork = + (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; + constexpr unsigned KBlockWork = + (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; + constexpr unsigned HBlockWork = + (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = + (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; + unsigned itmp = get_block_1d_id(); + const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); + const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + itmp -= k_block_work_id * (HBlockWork * WBlockWork); + const unsigned h_block_work_id = itmp / WBlockWork; + const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; + + const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; + const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock; + + const unsigned hi_block_data_begin = ho_block_data_begin; + const unsigned wi_block_data_begin = wo_block_data_begin; + + // tensor view of blockwise input and weight in LDS constexpr auto wei_srck_block_desc = make_ConstantTensorDescriptor(Sequence{}); - // matrix view of blockwise input and weight in LDS - constexpr auto in_cxhwn_block_mtx_desc = make_ConstantMatrixDescriptor( - Number, Number); + constexpr auto in_chwn_block_desc = + make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_srcxk_block_mtx_desc = - make_ConstantMatrixDescriptor(Number, Number); + // tensor view of threadwise output in register + constexpr auto out_hkwn_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] + const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + auto f_accum = [](auto& c, auto& ab) { c += ab; }; + + const auto blockwise_batch_gemm = + blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{}; // LDS constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); @@ -66,87 +128,59 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, __shared__ Float p_in_block[in_block_size]; __shared__ Float p_wei_block[wei_block_size]; - // a series of batched GEMM - // blockwise batched GEMM, C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, c_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_matrix[S*R*C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_matrix[C,Hi*Wi*N] - // C_matrix[K,Wo*N] is a sub-matrix of out_matrix[Ho*K,Wo*N] - constexpr auto a_block_mtx_desc = - wei_srcxk_block_mtx_desc.MakeSubMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_block_mtx_desc = in_cxhwn_block_mtx_desc.MakeSubMatrixDescriptor( - Number{}, Number{}); - - auto f_accum = (auto& c, auto& v) { c += v; }; - - const auto blockwise_batch_gemm = - blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{}; - - // tensor view of threadwise output in register - constexpr auto out_hkwn_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - // register Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; // set threadwise output tensor to 0 threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); - for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1); + for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); c_block_data_begin += CPerBlock, __syncthreads()) { // input: global mem to LDS, // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N] constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{}; - blockwise_4d_tensor_copy_reorder(in_nchw_global_desc, - p_in_global, - in_chwn_block_desc, - p_in_block, - in_chwn_block_desc, - reorder_nchw2chwn); + blockwise_4d_tensor_copy_reorder( + in_nchw_global_desc, + p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, + c_block_data_begin, + hi_block_data_begin, + wi_block_data_begin), + in_chwn_block_desc, + p_in_block, + in_chwn_block_desc, + reorder_nchw2chwn); // weight: global mem to LDS, // convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K] constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{}; - blockwise_4d_tensor_copy_reorder(wei_csrk_global_desc, - p_wei_global, - wei_csrk_block_desc, - p_wei_block, - wei_csrk_block_desc, - reorder_kcsr2csrk); + blockwise_4d_tensor_copy_reorder( + wei_kcsr_global_desc, + p_wei_global + + wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), + wei_srck_block_desc, + p_wei_block, + wei_srck_block_desc, + reorder_kcsr2srck); __syncthreads(); - // loop over filter point + // a series of batched GEMM for(unsigned s = 0; s < S; ++s) { for(unsigned r = 0; r < R; ++r) { - blockwise_batch_gemm.run( - p_wei_block + wei_srcxk_block_mtx_desc.Get1dIndex(xxxxx, xxxx), - p_in_block + in_cxhwn_block_mtx_desc.Get1dIndex(xxxx, xxxx), - p_out_thread); + blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0), + p_out_thread); } } } const auto matrix_c_index = - blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_id()); + blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; const unsigned k_thread_data_begin = matrix_c_index.col_begin; @@ -160,10 +194,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, out_hkwn_thread_desc, p_out_thread, out_nkhw_global_desc, - p_out_global + out_nkhw_global_desc.GetIndex(n_block_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin, + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin), out_hkwn_thread_desc, reorder_hkwn2nkhw); } diff --git a/src/include/gridwise_winograd_convolution.cuh b/src/include/gridwise_winograd_convolution.cuh index 2797844bdf..3d5b739263 100644 --- a/src/include/gridwise_winograd_convolution.cuh +++ b/src/include/gridwise_winograd_convolution.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" #include "blockwise_winograd_transform.cuh" #include "threadwise_winograd_transform.cuh" diff --git a/src/include/threadwise_direct_convolution.cuh b/src/include/threadwise_direct_convolution.cuh index 12c6ebac07..dfb6901fa7 100644 --- a/src/include/threadwise_direct_convolution.cuh +++ b/src/include/threadwise_direct_convolution.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" // optimized for scenario if p_in, p_wei, p_out are in register template diff --git a/src/include/threadwise_tensor_op.cuh b/src/include/threadwise_tensor_op.cuh index 7aeed4a764..fcb769ddc1 100644 --- a/src/include/threadwise_tensor_op.cuh +++ b/src/include/threadwise_tensor_op.cuh @@ -1,5 +1,5 @@ #pragma once -#include "constant_tensor_descriptor.cuh" +#include "ConstantTensorDescriptor.cuh" template __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)