mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactoring ConstantTensorDescriptor
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp"
|
||||
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
|
||||
@@ -57,27 +58,33 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw<T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
BlockSize,
|
||||
GridSize>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
float time = launch_kernel(
|
||||
#if 0
|
||||
gridwise_direct_convolution_2_nchw_kcyx_nkhw
|
||||
#else
|
||||
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#endif
|
||||
<T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
BlockSize,
|
||||
GridSize>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
|
||||
@@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(
|
||||
#if 1
|
||||
#if 0
|
||||
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
|
||||
#else
|
||||
gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
18
src/include/Array.hip.hpp
Normal file
18
src/include/Array.hip.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
template <class TData, unsigned NSize>
|
||||
struct Array
|
||||
{
|
||||
using Type = Array<TData, NSize>;
|
||||
|
||||
static constexpr unsigned nSize = NSize;
|
||||
|
||||
unsigned mData[nSize];
|
||||
|
||||
template <class... Xs>
|
||||
__host__ __device__ Array(Xs... xs) : mData({static_cast<TData>(xs)...})
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ TData operator[](unsigned i) const { return mData[i]; }
|
||||
};
|
||||
@@ -65,8 +65,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor<Lengths, Strides>;
|
||||
static constexpr unsigned nDim = Lengths::nDim;
|
||||
using NDimConstant = Number<nDim>;
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
{
|
||||
@@ -91,293 +91,70 @@ struct ConstantTensorDescriptor
|
||||
return Strides{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSize_f
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr unsigned operator()(IDim idim) const
|
||||
{
|
||||
return Type{}.GetLength(idim);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr unsigned GetElementSize() const
|
||||
{
|
||||
static_assert(nDim >= 2 && nDim <= 8, "nDim");
|
||||
|
||||
if(nDim == 2)
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct multiply
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
return GetLength(I0) * GetLength(I1);
|
||||
}
|
||||
else if(nDim == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2);
|
||||
}
|
||||
else if(nDim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
|
||||
}
|
||||
else if(nDim == 5)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4);
|
||||
}
|
||||
else if(nDim == 6)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5);
|
||||
}
|
||||
else if(nDim == 7)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5) * GetLength(I6);
|
||||
}
|
||||
else if(nDim == 8)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5) * GetLength(I6) * GetLength(I7);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
return static_const_reduce_n<nDim>{}(GetElementSize_f{}, multiply{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSpace_f
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr unsigned operator()(IDim idim) const
|
||||
{
|
||||
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
|
||||
{
|
||||
static_assert(nDim >= 2 && nDim <= 8, "nDim");
|
||||
|
||||
constexpr unsigned align_size = align.Get();
|
||||
|
||||
if(nDim == 2)
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct add
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
__host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + align_size;
|
||||
}
|
||||
else if(nDim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 5)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + align_size;
|
||||
}
|
||||
else if(nDim == 6)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 7)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
(GetLength(I6) - 1) * GetStride(I6) + align_size;
|
||||
}
|
||||
else if(nDim == 8)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
(GetLength(I6) - 1) * GetStride(I6) + (GetLength(I7) - 1) * GetStride(I7) +
|
||||
align_size;
|
||||
}
|
||||
return static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + align.Get();
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1) const
|
||||
template <class... Is>
|
||||
__host__ __device__ unsigned Get1dIndex(Is... is) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
static_assert(nDim == 2, "nDim is not 2");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1);
|
||||
}
|
||||
const auto multi_id = Array<unsigned, nDim>(is...);
|
||||
|
||||
// this is ugly, only for 3d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1, unsigned i2) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
unsigned id = 0;
|
||||
|
||||
static_assert(nDim == 3, "nDim is not 3");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2);
|
||||
}
|
||||
static_loop_n<nDim>{}([&](auto IDim) {
|
||||
constexpr unsigned idim = IDim.Get();
|
||||
id += multi_id[idim] * GetStride(IDim);
|
||||
});
|
||||
|
||||
// this is ugly, only for 4d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(nDim == 4, "nDim is not 4");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
|
||||
}
|
||||
|
||||
// this is ugly, only for 5d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
static_assert(nDim == 5, "nDim is not 5");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4);
|
||||
}
|
||||
|
||||
// this is ugly, only for 6d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4, unsigned i5) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
static_assert(nDim == 6, "nDim is not 6");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5);
|
||||
}
|
||||
|
||||
// this is ugly, only for 7d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0,
|
||||
unsigned i1,
|
||||
unsigned i2,
|
||||
unsigned i3,
|
||||
unsigned i4,
|
||||
unsigned i5,
|
||||
unsigned i6) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
static_assert(nDim == 7, "nDim is not 7");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6);
|
||||
}
|
||||
|
||||
// this is ugly, only for 8d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0,
|
||||
unsigned i1,
|
||||
unsigned i2,
|
||||
unsigned i3,
|
||||
unsigned i4,
|
||||
unsigned i5,
|
||||
unsigned i6,
|
||||
unsigned i7) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
static_assert(nDim == 8, "nDim is not 8");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6) + i7 * GetStride(I7);
|
||||
return id;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Condense() const
|
||||
@@ -385,6 +162,12 @@ struct ConstantTensorDescriptor
|
||||
constexpr auto default_strides = calculate_default_strides(Lengths{});
|
||||
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
|
||||
}
|
||||
|
||||
template <unsigned IDim, unsigned NVector>
|
||||
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
|
||||
{
|
||||
assert(false); // not implemented
|
||||
}
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
|
||||
92
src/include/Sequence.hip.hpp
Normal file
92
src/include/Sequence.hip.hpp
Normal file
@@ -0,0 +1,92 @@
|
||||
#pragma once
|
||||
#include "constant_integral.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
|
||||
template <unsigned... Is>
|
||||
struct Sequence
|
||||
{
|
||||
using Type = Sequence<Is...>;
|
||||
|
||||
static constexpr unsigned nDim = sizeof...(Is);
|
||||
|
||||
const unsigned mData[nDim] = {Is...};
|
||||
|
||||
template <unsigned I>
|
||||
__host__ __device__ constexpr unsigned Get(Number<I>) const
|
||||
{
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
// this is ugly, only for nDIm = 4
|
||||
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
|
||||
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
|
||||
{
|
||||
static_assert(nDim == 4, "nDim != 4");
|
||||
|
||||
constexpr auto old_sequence = Type{};
|
||||
|
||||
constexpr unsigned NR0 = old_sequence.mData[I0];
|
||||
constexpr unsigned NR1 = old_sequence.mData[I1];
|
||||
constexpr unsigned NR2 = old_sequence.mData[I2];
|
||||
constexpr unsigned NR3 = old_sequence.mData[I3];
|
||||
|
||||
return Sequence<NR0, NR1, NR2, NR3>{};
|
||||
}
|
||||
|
||||
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
|
||||
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
|
||||
{
|
||||
// don't know how to implement this
|
||||
printf("Sequence::ReorderByPutOldToNew not implemented");
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <unsigned I>
|
||||
__host__ __device__ constexpr auto PushBack(Number<I>) const
|
||||
{
|
||||
return Sequence<Is..., I>{};
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto PopBack() const;
|
||||
|
||||
template <class F>
|
||||
__host__ __device__ constexpr auto Transform(F f) const
|
||||
{
|
||||
return Sequence<f(Is)...>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned... Is, unsigned I>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
|
||||
{
|
||||
static_assert(sizeof...(Is) >= 1, "empty Sequence!");
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
template <class F, unsigned... Xs, unsigned... Ys>
|
||||
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
|
||||
{
|
||||
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
|
||||
|
||||
return Sequence<f(Xs, Ys)...>{};
|
||||
}
|
||||
|
||||
template <unsigned... Xs, unsigned... Ys>
|
||||
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
struct add
|
||||
{
|
||||
__host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const
|
||||
{
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
|
||||
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
|
||||
}
|
||||
|
||||
template <unsigned... Is>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
|
||||
{
|
||||
return sequence_pop_back(Type{});
|
||||
}
|
||||
@@ -1,4 +1,8 @@
|
||||
#pragma once
|
||||
#include "constant_integral.hip.hpp"
|
||||
#include "Sequence.hip.hpp"
|
||||
#include "Array.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
|
||||
__device__ unsigned get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
@@ -91,54 +95,6 @@ struct vector_type<half, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
template <class T, T N>
|
||||
struct integral_constant
|
||||
{
|
||||
static const T value = N;
|
||||
|
||||
__host__ __device__ constexpr T Get() const { return value; }
|
||||
};
|
||||
|
||||
template <unsigned N>
|
||||
using Number = integral_constant<unsigned, N>;
|
||||
|
||||
template <unsigned... Is>
|
||||
struct Sequence
|
||||
{
|
||||
using Type = Sequence<Is...>;
|
||||
|
||||
static constexpr unsigned nDim = sizeof...(Is);
|
||||
|
||||
const unsigned mData[nDim] = {Is...};
|
||||
|
||||
template <unsigned I>
|
||||
__host__ __device__ constexpr unsigned Get(Number<I>) const
|
||||
{
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
|
||||
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
|
||||
{
|
||||
constexpr auto old_sequence = Type{};
|
||||
|
||||
constexpr unsigned NR0 = old_sequence.mData[I0];
|
||||
constexpr unsigned NR1 = old_sequence.mData[I1];
|
||||
constexpr unsigned NR2 = old_sequence.mData[I2];
|
||||
constexpr unsigned NR3 = old_sequence.mData[I3];
|
||||
|
||||
return Sequence<NR0, NR1, NR2, NR3>{};
|
||||
}
|
||||
|
||||
template <unsigned I0, unsigned I1, unsigned I2, unsigned I3>
|
||||
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
|
||||
{
|
||||
// don't know how to implement this
|
||||
printf("Sequence::ReorderByPutOldToNew not implemented");
|
||||
assert(false);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T max(T a, T b)
|
||||
{
|
||||
|
||||
12
src/include/constant_integral.hip.hpp
Normal file
12
src/include/constant_integral.hip.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
template <class T, T N>
|
||||
struct integral_constant
|
||||
{
|
||||
static const T value = N;
|
||||
|
||||
__host__ __device__ constexpr T Get() const { return value; }
|
||||
};
|
||||
|
||||
template <unsigned N>
|
||||
using Number = integral_constant<unsigned, N>;
|
||||
49
src/include/functional.hip.hpp
Normal file
49
src/include/functional.hip.hpp
Normal file
@@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
#include "constant_integral.hip.hpp"
|
||||
|
||||
template <unsigned NLoop>
|
||||
struct static_loop_n
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(NLoop > 1, "out-of-range");
|
||||
|
||||
f(Number<NLoop - 1>{});
|
||||
static_loop_n<NLoop - 1>{}(f);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_loop_n<1>
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
f(Number<0>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned NLoop>
|
||||
struct static_const_reduce_n
|
||||
{
|
||||
template <class F, class Reduce>
|
||||
__host__ __device__ constexpr auto operator()(F f, Reduce r) const
|
||||
{
|
||||
static_assert(NLoop > 1, "out-of-range");
|
||||
|
||||
constexpr auto a = f(Number<NLoop - 1>{});
|
||||
auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // cannot use constexpr here, weird
|
||||
return r(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_const_reduce_n<1>
|
||||
{
|
||||
template <class F, class Reduce>
|
||||
__host__ __device__ constexpr auto operator()(F f, Reduce) const
|
||||
{
|
||||
return f(Number<0>{});
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,195 @@
|
||||
#pragma once
|
||||
#include "common.hip.hpp"
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
#include "blockwise_4d_tensor_op.hip.hpp"
|
||||
#include "blockwise_direct_convolution.hip.hpp"
|
||||
#include "threadwise_4d_tensor_op.hip.hpp"
|
||||
#include "threadwise_direct_convolution.hip.hpp"
|
||||
|
||||
template <class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
unsigned NPerBlock,
|
||||
unsigned KPerBlock,
|
||||
unsigned CPerBlock,
|
||||
unsigned HoPerBlock,
|
||||
unsigned WoPerBlock,
|
||||
unsigned NPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread,
|
||||
unsigned BlockSize,
|
||||
unsigned GridSize>
|
||||
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr unsigned Y = wei_global_desc.GetLength(I2);
|
||||
constexpr unsigned X = wei_global_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
constexpr auto in_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{});
|
||||
|
||||
constexpr auto wei_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{});
|
||||
|
||||
// shared mem
|
||||
constexpr unsigned in_block_size = in_block_desc.GetElementSpace();
|
||||
constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_size];
|
||||
__shared__ Float p_wei_block[wei_block_size];
|
||||
|
||||
// threadwise tensors
|
||||
constexpr unsigned HiPerThread = HoPerThread + Y - 1;
|
||||
constexpr unsigned WiPerThread = WoPerThread + X - 1;
|
||||
|
||||
constexpr auto in_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{}, in_block_desc.GetStrides());
|
||||
|
||||
constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, CPerThread, Y, X>{}, wei_block_desc.GetStrides());
|
||||
|
||||
constexpr auto out_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
in_thread_block_desc, wei_thread_block_desc);
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_thread_desc.GetElementSpace()];
|
||||
|
||||
// divide block work
|
||||
constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
|
||||
constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
|
||||
constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
|
||||
|
||||
const unsigned block_id = blockIdx.x;
|
||||
|
||||
unsigned itmp = block_id;
|
||||
const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
|
||||
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
|
||||
const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork);
|
||||
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
|
||||
const unsigned h_block_work_id = itmp / WBlockWork;
|
||||
const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork;
|
||||
|
||||
const unsigned n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
|
||||
const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding
|
||||
const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding
|
||||
|
||||
// divide thread work
|
||||
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
|
||||
constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
|
||||
constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
|
||||
constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
|
||||
|
||||
const unsigned thread_id = threadIdx.x;
|
||||
|
||||
itmp = thread_id;
|
||||
const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
|
||||
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
|
||||
const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork);
|
||||
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
|
||||
const unsigned h_thread_work_id = itmp / WThreadWork;
|
||||
const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
|
||||
|
||||
const unsigned n_thread_data_begin = n_thread_work_id * NPerThread;
|
||||
const unsigned k_thread_data_begin = k_thread_work_id * KPerThread;
|
||||
const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread;
|
||||
const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread;
|
||||
|
||||
const unsigned hi_thread_data_begin = ho_thread_data_begin;
|
||||
const unsigned wi_thread_data_begin = wo_thread_data_begin;
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_global_desc),
|
||||
decltype(in_block_desc),
|
||||
decltype(in_block_desc.GetLengths())>{};
|
||||
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_global_desc),
|
||||
decltype(wei_block_desc),
|
||||
decltype(wei_block_desc.GetLengths())>{};
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_thread_desc, p_out_thread);
|
||||
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1);
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
|
||||
{
|
||||
// threadwise convolution
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_thread_block_desc,
|
||||
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_thread_block_desc,
|
||||
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_thread_desc,
|
||||
p_out_thread);
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_thread_block_desc,
|
||||
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_thread_block_desc,
|
||||
p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_thread_desc,
|
||||
p_out_thread);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// copy output tensor from register to global mem
|
||||
threadwise_4d_tensor_copy(
|
||||
out_thread_desc,
|
||||
p_out_thread,
|
||||
out_global_desc,
|
||||
p_out_global + out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_thread_desc.GetLengths());
|
||||
}
|
||||
Reference in New Issue
Block a user