mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor ConstantTensorDescriptor and functional
This commit is contained in:
@@ -8,12 +8,12 @@
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
index_t nrepeat)
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
index_t nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -7,12 +7,12 @@
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
index_t nrepeat)
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
index_t nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -52,7 +52,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
in_nchw_vec(n, c, h, w) =
|
||||
vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
|
||||
#elif 1
|
||||
in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
|
||||
in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
|
||||
in_nchw(n, 4 * c + 1, h, w),
|
||||
in_nchw(n, 4 * c + 2, h, w),
|
||||
in_nchw(n, 4 * c + 3, h, w));
|
||||
@@ -114,37 +114,37 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3, 34x34, 128 thread, fp32, vector = 2
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 4;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 4;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3, 34x34, 128 thread, int8, vector = 4
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t NPerThread = 1;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t CPerThread = 2;
|
||||
constexpr index_t NPerThread = 1;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t CPerThread = 2;
|
||||
constexpr index_t HoPerThread = 4;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
|
||||
//#include "device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
|
||||
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp"
|
||||
//#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
|
||||
#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
|
||||
|
||||
@@ -48,13 +49,10 @@ struct GeneratorTensor_3
|
||||
#if 0
|
||||
auto f_acc = std::plus<index_t>{};
|
||||
#else
|
||||
auto f_acc = [](auto a, auto b){ return 10*a + b;};
|
||||
auto f_acc = [](auto a, auto b) { return 10 * a + b; };
|
||||
#endif
|
||||
|
||||
return std::accumulate(dims.begin(),
|
||||
dims.end(),
|
||||
index_t(0),
|
||||
f_acc);
|
||||
return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -376,7 +374,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
@@ -435,13 +433,13 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
@@ -505,7 +503,7 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 512;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
@@ -666,6 +664,8 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
|
||||
#endif
|
||||
|
||||
@@ -14,5 +14,7 @@ struct Array
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ TData operator[](index_t i) const { return mData[i]; }
|
||||
__host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }
|
||||
|
||||
__host__ __device__ TData& operator[](index_t i) { return mData[i]; }
|
||||
};
|
||||
|
||||
@@ -115,46 +115,27 @@ struct ConstantTensorDescriptor
|
||||
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t GetDimension() const { return nDim; }
|
||||
__host__ __device__ static constexpr index_t GetDimension() { return nDim; }
|
||||
|
||||
__host__ __device__ constexpr Lengths GetLengths() const { return Lengths{}; }
|
||||
__host__ __device__ static constexpr Lengths GetLengths() { return Lengths{}; }
|
||||
|
||||
__host__ __device__ constexpr Strides GetStrides() const { return Strides{}; }
|
||||
__host__ __device__ static constexpr Strides GetStrides() { return Strides{}; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetLength(Number<I>) const
|
||||
__host__ __device__ static constexpr index_t GetLength(Number<I>)
|
||||
{
|
||||
return Lengths{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t GetStride(Number<I>) const
|
||||
__host__ __device__ static constexpr index_t GetStride(Number<I>)
|
||||
{
|
||||
return Strides{}.Get(Number<I>{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct GetElementSize_f
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim idim) const
|
||||
{
|
||||
return Type{}.GetLength(idim);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr index_t GetElementSize() const
|
||||
{
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct multiply
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
return static_const_reduce_n<nDim>{}(GetElementSize_f{}, multiply{});
|
||||
return accumulate_on_sequence(Lengths{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
}
|
||||
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
@@ -168,25 +149,16 @@ struct ConstantTensorDescriptor
|
||||
};
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ constexpr index_t GetElementSpace(Align align = Align{}) const
|
||||
__host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
|
||||
{
|
||||
// c++14 doesn't support constexpr lambdas, has to use this trick instead
|
||||
struct add
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t a, index_t b) const
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
index_t element_space_unaligned =
|
||||
static_const_reduce_n<nDim>{}(GetElementSpace_f{}, add{}) + 1;
|
||||
static_const_reduce_n<nDim>{}(GetElementSpace_f{}, mod_conv::plus<index_t>{}) + 1;
|
||||
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ index_t Get1dIndex(Is... is) const
|
||||
__host__ __device__ static index_t Get1dIndex(Is... is)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
@@ -194,7 +166,7 @@ struct ConstantTensorDescriptor
|
||||
|
||||
index_t id = 0;
|
||||
|
||||
static_loop_n<nDim>{}([&](auto IDim) {
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
#if DEVICE_BACKEND_HIP
|
||||
id += __mul24(multi_id[idim], GetStride(IDim));
|
||||
@@ -206,17 +178,26 @@ struct ConstantTensorDescriptor
|
||||
return id;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Condense() const
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
static_for<0, nDim - 1, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
multi_id[idim] = id / GetStride(IDim);
|
||||
id -= multi_id[idim] * GetStride(IDim);
|
||||
});
|
||||
|
||||
multi_id[nDim - 1] = id / GetStride(Number<nDim - 1>{});
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Condense()
|
||||
{
|
||||
constexpr auto default_strides = calculate_default_strides(Lengths{});
|
||||
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t NVector>
|
||||
__host__ __device__ constexpr auto Vectorize(Number<IDim>, Number<NVector>) const
|
||||
{
|
||||
assert(false); // not implemented
|
||||
}
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
|
||||
@@ -17,6 +17,8 @@ struct Sequence
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
__host__ __device__ index_t operator[](index_t i) const { return mData[i]; }
|
||||
|
||||
// this is ugly, only for nDIm = 4
|
||||
template <index_t I0, index_t I1, index_t I2, index_t I3>
|
||||
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
|
||||
@@ -90,3 +92,21 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
|
||||
{
|
||||
return sequence_pop_back(Type{});
|
||||
}
|
||||
|
||||
template <class Seq>
|
||||
struct accumulate_on_sequence_f
|
||||
{
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim) const
|
||||
{
|
||||
return Seq{}.Get(IDim{});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Seq, class Reduce, index_t I>
|
||||
__host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number<I>)
|
||||
{
|
||||
constexpr index_t a =
|
||||
static_const_reduce_n<Seq::nDim>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});
|
||||
return Reduce{}(a, I);
|
||||
}
|
||||
|
||||
@@ -211,8 +211,7 @@ struct Blockwise2dTensorCopy1
|
||||
|
||||
constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
@@ -225,10 +224,8 @@ struct Blockwise2dTensorCopy1
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);
|
||||
const index_t src_index = src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
|
||||
const index_t dst_index = dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
|
||||
@@ -54,8 +54,7 @@ struct Blockwise3dTensorCopy1
|
||||
|
||||
constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
@@ -72,10 +71,8 @@ struct Blockwise3dTensorCopy1
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
const index_t src_index = src_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
const index_t dst_index = dst_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
|
||||
@@ -340,11 +340,10 @@ struct BlockwiseChwnTensorCopyPadded
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
const Float* p_src_tmp =
|
||||
p_src +
|
||||
src_desc.Get1dIndex(c_block_data_begin,
|
||||
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
|
||||
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
|
||||
n_block_data_begin);
|
||||
p_src + src_desc.Get1dIndex(c_block_data_begin,
|
||||
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
|
||||
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
|
||||
n_block_data_begin);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
@@ -494,7 +493,7 @@ struct Blockwise4dTensorCopy3
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
|
||||
accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
@@ -504,19 +503,18 @@ struct Blockwise4dTensorCopy3
|
||||
}
|
||||
}
|
||||
|
||||
const index_t thread_id_d0 =
|
||||
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
|
||||
index_t itmp = get_thread_local_1d_id() -
|
||||
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
|
||||
const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
|
||||
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
|
||||
const index_t thread_id_d2 = itmp / thread_per_d3;
|
||||
const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
|
||||
const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
|
||||
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
|
||||
mDstMyThreadOffset = DstDesc{}.Get1dIndex(
|
||||
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
|
||||
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
thread_multi_id[2],
|
||||
thread_multi_id[3] * DataPerRead);
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
thread_multi_id[2],
|
||||
thread_multi_id[3] * DataPerRead);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
@@ -745,3 +743,113 @@ struct Blockwise4dTensorCopy3
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
struct Blockwise4dTensorCopyReorder1
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcLengths,
|
||||
class SrcSubLengths,
|
||||
class SrcThreadPerDims,
|
||||
class DstFromSrcReorder,
|
||||
index_t DataPerRead,
|
||||
index_t DataPerWrite>
|
||||
struct Blockwise4dTensorCopyReorder3
|
||||
{
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise4dTensorCopyReorder3()
|
||||
{
|
||||
constexpr index_t nDim = SrcDesc{}.GetDimension();
|
||||
|
||||
static_assert(DstDesc{}.GetDimension() == nDim && SrcOpLengths::nDim == nDim &&
|
||||
SrcOpThreadPerDims::nDim == nDim && DstFromSrcReorder::nDim == nDim,
|
||||
"wrong! nDim is not consistent\n");
|
||||
|
||||
// Src
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || SrcDesc{}.GetStride(Number<nDim-1>{}) == 1,
|
||||
"wrong! only support src.stride(nDim-1) == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(Number<nDim-2>{}) % DataPerRead == 0,
|
||||
"wrong! src.stride(nDim-2) should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
static_assert(SrcSubLengths{}.Get(Number<nDim-1>{}) % DataPerRead == 0, "wrong! SrcSubLengths[nDim-1] % DataPerRead != 0\n");
|
||||
|
||||
static_loop<nDim-1>([](auto I){
|
||||
constexpr index_t src_len = SrcLengths{}.Get(I);
|
||||
constexpr index_t src_sub_len = SrcSubLengths{}.Get(I);
|
||||
constexpr index_t thread_per_dim = SrcThreadPerDims{}.Get(I);
|
||||
static_assert(src_len % (src_sub_len * thread_per_dim) == 0,
|
||||
"wrong! cannot evenly divide tensor lengths");
|
||||
});
|
||||
|
||||
constexpr index_t num_active_thread = accumulate_on_sequence(SrcOpThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
static_assert(BlockSize >= num_active_thread,
|
||||
"wrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const auto thread_multi_id = SrcOpThreadPerDims::GetMultiIndex(get_thread_local_1d_id());
|
||||
|
||||
|
||||
const index_t thread_id_d0 =
|
||||
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
|
||||
index_t itmp = get_thread_local_1d_id() -
|
||||
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
|
||||
const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
|
||||
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
|
||||
const index_t thread_id_d2 = itmp / thread_per_d3;
|
||||
const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
|
||||
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
|
||||
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
|
||||
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -393,9 +393,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
c_thread_sub_mtx,
|
||||
p_c_thread +
|
||||
c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster),
|
||||
p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster),
|
||||
c_block_mtx,
|
||||
p_c_block +
|
||||
c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
|
||||
@@ -93,11 +93,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
|
||||
Float p_out_thread[out_thread_desc.GetElementSpace()];
|
||||
|
||||
threadwise_4d_tensor_copy(out_block_desc,
|
||||
p_out_block +
|
||||
out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
out_thread_desc,
|
||||
p_out_thread,
|
||||
out_thread_desc.GetLengths());
|
||||
@@ -108,11 +107,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
|
||||
// threadwise convolution
|
||||
threadwise_direct_convolution_2(
|
||||
in_thread_block_desc,
|
||||
p_in_block +
|
||||
in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data_begin,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data_begin,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data_begin, 0, 0),
|
||||
@@ -124,11 +122,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
|
||||
threadwise_4d_tensor_copy(out_thread_desc,
|
||||
p_out_thread,
|
||||
out_block_desc,
|
||||
p_out_block +
|
||||
out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
out_thread_desc.GetLengths());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,41 @@
|
||||
#pragma once
|
||||
#include "constant_integral.hip.hpp"
|
||||
|
||||
template <index_t NLoop>
|
||||
struct static_loop_n
|
||||
template <index_t Iter, index_t Remaining, index_t Increment>
|
||||
struct static_for_impl
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
static_assert(NLoop > 1, "out-of-range");
|
||||
static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
|
||||
static_assert(Increment <= Remaining, "will go out-of-range");
|
||||
|
||||
f(Number<NLoop - 1>{});
|
||||
static_loop_n<NLoop - 1>{}(f);
|
||||
f(Number<Iter>{});
|
||||
static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_loop_n<1>
|
||||
template <index_t Iter, index_t Increment>
|
||||
struct static_for_impl<Iter, 0, Increment>
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F) const
|
||||
{
|
||||
// do nothing
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
f(Number<0>{});
|
||||
static_assert(NBegin < NEnd, "Wrong! we should have NBegin < NEnd");
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -54,4 +69,19 @@ __host__ __device__ constexpr auto unpacker(F f)
|
||||
{
|
||||
return [=](auto xs_array){ f(xs...); };
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
namespace mod_conv {
|
||||
template <class T>
|
||||
struct multiplies
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct plus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
} // namespace mod_conv
|
||||
|
||||
@@ -99,8 +99,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
|
||||
constexpr index_t max_align = mod_conv::max(
|
||||
InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
@@ -135,16 +135,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*Y*X,KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*Y*X,KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
@@ -202,9 +201,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_chwn_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
p_in_global + in_chwn_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -323,17 +321,16 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_10d_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
threadwise_10d_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global + out_khwn_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -190,9 +190,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_chwn_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
p_in_global + in_chwn_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -393,17 +392,16 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_10d_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
threadwise_10d_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global + out_khwn_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -101,8 +101,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
|
||||
constexpr index_t max_align = mod_conv::max(
|
||||
InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
@@ -116,8 +116,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
const auto blockwise_in_copy =
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
@@ -125,6 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#else
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
@@ -150,10 +151,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<KPerBlock>{},
|
||||
Number<wei_c_x_k_block_desc.GetStride(I0)>{});
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_x_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
@@ -187,8 +186,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space = wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space =
|
||||
wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
@@ -213,9 +214,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -227,7 +227,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global_block_offset +
|
||||
in_c_h_w_n_global_desc.Get1dIndex(0, y, 0, 0),
|
||||
in_c_h_w_n_global_desc.Get1dIndex(0, y, 0, 0),
|
||||
p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset +
|
||||
@@ -239,9 +239,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
|
||||
p_in_block + in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
|
||||
p_in_block +
|
||||
in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
|
||||
p_out_thread);
|
||||
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@@ -321,17 +321,16 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_10d_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
threadwise_10d_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -365,14 +365,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
|
||||
|
||||
threadwise_6d_tensor_copy(
|
||||
out_6d_thread_desc,
|
||||
p_out_thread,
|
||||
out_6d_global_desc,
|
||||
p_out_global +
|
||||
out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin),
|
||||
out_6d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
threadwise_6d_tensor_copy(out_6d_thread_desc,
|
||||
p_out_thread,
|
||||
out_6d_global_desc,
|
||||
p_out_global + out_kb_global_desc.Get1dIndex(
|
||||
k_thread_data_begin, b_thread_data_begin),
|
||||
out_6d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -113,11 +113,10 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
|
||||
c_block_work_begin += CPerBlock)
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(p_in_global +
|
||||
in_global_desc.Get1dIndex(n_block_work_begin,
|
||||
c_block_work_begin,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
|
||||
c_block_work_begin,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
@@ -144,9 +143,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
|
||||
}
|
||||
|
||||
// copy output tensor from LDS to device mem
|
||||
blockwise_out_copy.Run(
|
||||
p_out_block,
|
||||
p_out_global +
|
||||
out_global_desc.Get1dIndex(
|
||||
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
|
||||
blockwise_out_copy.Run(p_out_block,
|
||||
p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
|
||||
k_block_work_begin,
|
||||
ho_block_work_begin,
|
||||
wo_block_work_begin));
|
||||
}
|
||||
|
||||
@@ -175,18 +175,16 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(p_in_global +
|
||||
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global +
|
||||
wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -196,11 +194,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block +
|
||||
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -209,11 +206,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block +
|
||||
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -228,10 +224,9 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global +
|
||||
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths());
|
||||
}
|
||||
|
||||
@@ -198,10 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
p_in_vec_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_vec_global +
|
||||
wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_vec_block);
|
||||
blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_vec_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -211,11 +210,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -224,11 +222,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -243,10 +240,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global +
|
||||
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths());
|
||||
}
|
||||
|
||||
@@ -283,11 +283,10 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_khwn_global_desc,
|
||||
p_out_global +
|
||||
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_khwn_from_hkwn);
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
|
||||
return os;
|
||||
}
|
||||
|
||||
typedef enum {
|
||||
typedef enum
|
||||
{
|
||||
Half = 0,
|
||||
Float = 1,
|
||||
} DataType_t;
|
||||
|
||||
Reference in New Issue
Block a user