adding implicit gemm v3

src/include/ConstantMergedTensorDescriptor.hip.hpp (new file, 95 lines)
@@ -0,0 +1,95 @@
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"

// TensorDesc: ConstantTensorDescriptor<...>
// MergedDimRanges: Sequence<FirstMergedDim, LastMergedDim>
template <class TensorDesc, class... MergedDimRanges>
struct ConstantMergedTensorDescriptor
{
    static constexpr index_t nOriginalDim = GetNumOfOriginalDimension();
    static constexpr index_t nDim         = GetNumOfDimension();

    template <class... Is>
    __host__ __device__ constexpr ConstantMergedTensorDescriptor()
    {
        constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges{}...);

        static_for<0, sizeof...(MergedDimRanges), 1>{}([&](auto I) {
            constexpr index_t i = I.Get();
            constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);

            static_assert(merged_dim_range.GetSize() == 2,
                          "wrong! should specify first and last dimension to be merged");
            static_assert(merged_dim_range.Get(Number<0>{}) < GetNumOfUnmergedDimension(),
                          "wrong!");
            static_assert(merged_dim_range.Get(Number<1>{}) < GetNumOfUnmergedDimension(),
                          "wrong!");
            static_assert(merged_dim_range.Get(Number<0>{}) <= merged_dim_range.Get(Number<1>{}),
                          "wrong!");
        });
    }

    __host__ __device__ static constexpr index_t GetNumOfOriginalDimension()
    {
        return TensorDesc::GetNumOfDimension();
    }

    __host__ __device__ static constexpr index_t GetNumOfDimension()
    {
        constexpr auto merged_dim_ranges = std::make_tuple(MergedDimRanges{}...);

        struct f_calculate_num_of_lost_dim
        {
            __host__ __device__ constexpr index_t operator()(auto I) const
            {
                constexpr index_t i = I.Get();
                constexpr auto merged_dim_range = std::get<i>(merged_dim_ranges);

                return merged_dim_range.Get(Number<1>{}) - merged_dim_range.Get(Number<0>{});
            }
        };

        constexpr index_t num_lost_dim = static_const_reduce_n<sizeof...(MergedDimRanges)>{}(
            f_calculate_num_of_lost_dim{}, mod_conv::plus<index_t>{});

        return TensorDesc::GetNumOfDimension() - num_lost_dim;
    }

    template <index_t IDim>
    __host__ __device__ static constexpr bool IsMergedDimension(Number<IDim>)
    {
        // not implemented
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetLength(Number<IDim>)
    {
        // not implemented
    }

    template <index_t IDim>
    __host__ __device__ static constexpr index_t GetStride(Number<IDim>)
    {
        static_assert(!IsMergedDimension(Number<IDim>{}),
                      "wrong! A merged dimension does not have uniform stride");
        // not implemented
    }

    template <class... Is>
    __host__ __device__ auto MultiIndex2OriginalMultiIndex(Is... is) const
    {
        // not implemented
    }

    template <class... Is>
    __host__ __device__ auto OriginalMultiIndex2MultiIndex(Is... is) const
    {
        // not implemented
    }
};

template <class TensorDesc, class... MergedDimRanges>
constexpr auto make_ConstantMergedTensorDescriptor(TensorDesc, MergedDimRanges...)
{
    return ConstantMergedTensorDescriptor<TensorDesc, MergedDimRanges...>{};
}
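
// Illustrative usage (hypothetical shapes; a sketch, not code from this commit):
// merge H and W (original dimensions 2 and 3) of a packed NCHW descriptor into
// one dimension, giving a 3d merged view [N, C, H*W].
//
//   constexpr auto in_nchw_desc =
//       make_ConstantTensorDescriptor(Sequence<8, 16, 34, 34>{});
//
//   constexpr auto in_merged_desc =
//       make_ConstantMergedTensorDescriptor(in_nchw_desc, Sequence<2, 3>{});
//
// For this merged descriptor, nOriginalDim would be 4 and nDim would be 3.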

@@ -65,7 +65,7 @@ struct ConstantTensorDescriptor
        static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
    }

    __host__ __device__ static constexpr index_t GetDimension() { return nDim; }
    __host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }

    __host__ __device__ static constexpr Lengths GetLengths() { return Lengths{}; }

@@ -160,11 +160,51 @@ struct ConstantTensorDescriptor
        return multi_id;
    }

    __host__ __device__ static constexpr auto Condense()
    __host__ __device__ static constexpr auto Pack()
    {
        constexpr auto default_strides = calculate_default_strides(Lengths{});
        return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
    }

    template <index_t... IDims>
    __host__ __device__ static constexpr auto Extract(Number<IDims>... /*extracted_dims*/)
    {
        static_assert(sizeof...(IDims) <= GetNumOfDimension(), "wrong!");

        constexpr auto extracted_lengths = Sequence<Lengths{}.Get(Number<IDims>{})...>{};
        constexpr auto extracted_strides = Sequence<Strides{}.Get(Number<IDims>{})...>{};

        return make_ConstantTensorDescriptor(extracted_lengths, extracted_strides);
    }

    template <index_t IDim, index_t SliceLen>
    __host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
    {
        // not implemented
    }

    template <index_t IDim, index_t... FoldLengths>
    __host__ __device__ static constexpr auto Fold(Number<IDim>, Sequence<FoldLengths...>)
    {
        // not implemented
        // need to check that the length of the dimension being folded is divisible by FoldLengths
    }

    template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
    __host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
    {
        // not implemented
        // need to check that the dimensions to be unfolded are packed; otherwise Unfold is not permitted
    }

    template <index_t... IRs>
    __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
    {
        static_assert(sizeof...(IRs) == GetNumOfDimension(), "wrong! dimension is wrong");
        constexpr auto map_new2old = Sequence<IRs...>{};
        return make_ConstantTensorDescriptor(Lengths{}.ReorderGivenNew2Old(map_new2old),
                                             Strides{}.ReorderGivenNew2Old(map_new2old));
    }
};
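
// Illustrative usage of the new calls above (hypothetical lengths/strides;
// a sketch, not code from this commit):
//
//   constexpr auto desc =
//       make_ConstantTensorDescriptor(Sequence<2, 3, 4>{}, Sequence<24, 8, 1>{});
//
//   constexpr auto packed   = decltype(desc)::Pack();            // recompute default (packed) strides
//   constexpr auto sub      = decltype(desc)::Extract(Number<0>{}, Number<2>{}); // keep dims 0 and 2
//   constexpr auto permuted = decltype(desc)::ReorderGivenNew2Old(Sequence<2, 0, 1>{});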

template <class Lengths>
@@ -191,7 +231,7 @@ template <class TDesc>
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
{
    constexpr auto desc = TDesc{};
    constexpr index_t ndim = desc.GetDimension();
    constexpr index_t ndim = desc.GetNumOfDimension();

    static_assert(ndim >= 2 && ndim <= 10, "wrong!");

@@ -202,7 +242,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)

    printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetStride(I0),
@@ -216,7 +256,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)

    printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
@@ -233,7 +273,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)

    printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
@@ -253,7 +293,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)

    printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
@@ -276,7 +316,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)

    printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
@@ -302,7 +342,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)

    printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
@@ -331,7 +371,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)

    printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
@@ -364,7 +404,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
    printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
           "%u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),
@@ -400,7 +440,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
    printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
           "%u %u %u}\n",
           s,
           desc.GetDimension(),
           desc.GetNumOfDimension(),
           desc.GetLength(I0),
           desc.GetLength(I1),
           desc.GetLength(I2),

@@ -59,10 +59,22 @@ struct Sequence

    __host__ __device__ constexpr auto PopBack() const;

    template <class F>
    __host__ __device__ constexpr auto Transform(F f) const
    template <index_t I, index_t X>
    __host__ __device__ constexpr auto Insert(Number<I>, Number<X>) const
    {
        return Sequence<f(Is)...>{};
        index_t data[mSize + 1];

        static_for<0, I, 1>{}([&](auto Iter) {
            constexpr index_t iter = Iter.Get();
            data[iter] = mData[iter];
        });

        data[I] = X;

        static_for<I, mSize, 1>{}([&](auto Iter) {
            constexpr index_t iter = Iter.Get();
            data[iter + 1] = mData[iter];
        });
    }
};
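
// Illustrative usage (a sketch; assumes Insert will eventually return the
// extended Sequence, which the body above does not do yet):
//
//   constexpr auto s = Sequence<1, 2, 4>{};
//   constexpr auto t = s.Insert(Number<2>{}, Number<3>{}); // expected: Sequence<1, 2, 3, 4>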

@@ -1,6 +1,6 @@
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"

template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void

@@ -7,7 +7,6 @@
template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
@@ -35,51 +34,35 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
        static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong! Cannot evenly divide work among\n");

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat \n");

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");
        static_assert(ThreadMatrixC::GetLengths() == GetThreadMatrixCLengths(),
                      "wrong! ThreadMatrixC lengths are wrong");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
        mMyThreadOffsetA = BlockMatrixA::Get1dIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = BlockMatrixB::Get1dIndex(0, c_thread_mtx_index.col);
    }

    __device__ static auto GetThreadMatrixCLengths()
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);

        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
    }
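
    // Numeric example (hypothetical parameters; assumes ThreadPerLevel1Cluster is the
    // product of the four cluster sizes): with MPerThreadSubC = NPerThreadSubC = 4,
    // MLevel0Cluster = NLevel0Cluster = 4, MLevel1Cluster = NLevel1Cluster = 4 and
    // M = N = 128, MRepeat = NRepeat = 128 / (4 * 4 * 4) = 2, so each thread owns an
    // 8 x 8 tile of C (Sequence<8, 8>) and 4 * 4 * 4 * 4 = 256 threads make up the block.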

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
@@ -101,7 +84,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // this should be optimized away if input is known
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
                                                                      index_t n_in_c)
    {

src/include/blockwise_merged_tensor_slice_op.hip.hpp (new file, 55 lines)
@@ -0,0 +1,55 @@
#pragma once
#include "threadwise_tensor_slice_op.hip.hpp"

// slice a merged tensor, reorder and copy it into a normal tensor
// src: a merged tensor
// dst: a normal tensor
template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SliceLengths,
          class SubLengths,
          class ClusterLengths,
          class ThreadArrangeOrder,
          class SrcAccessOrder,
          class DstAccessOrder>
struct BlockwiseTensorSliceCopy_generic_v1
{
    static constexpr index_t nDim = SrcDesc::GetNumOfDimension();

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ BlockwiseTensorSliceCopy_generic_v1(Array<index_t, nDim> src_block_multi_id_offset,
                                                   Array<index_t, nDim> dst_block_multi_id_offset)
    {
        // only support SrcSubLengths.GetLength() == 1 on merged dimension, for now
        // check SrcDataPerRead should be 1, if last dimension is a merged dimension

        // check nDim consistent

        // calculate mSrcMyThreadOffset
        // calculate mDstMyThreadOffset
    }

    __device__ static constexpr index_t GetRegisterClipboardSize() {}

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        Float p_clipboard[GetRegisterClipboardSize()];

        RunLoadRegisterClipboard(p_src, p_clipboard);
        RunStoreRegisterClipboard(p_clipboard, p_dst);
    }
};
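
// Illustrative call pattern (a sketch with hypothetical template arguments;
// the stubs above still have to be filled in):
//
//   using InCopy = BlockwiseTensorSliceCopy_generic_v1<256,
//                                                      float,
//                                                      SrcDesc,
//                                                      DstDesc,
//                                                      Sequence<1, 2, 8, 64>, // slice lengths
//                                                      Sequence<1, 1, 1, 4>,  // per-thread sub-lengths
//                                                      Sequence<1, 2, 8, 16>, // thread cluster lengths
//                                                      Sequence<0, 1, 2, 3>,
//                                                      Sequence<0, 1, 2, 3>,
//                                                      Sequence<0, 1, 2, 3>>;
//
//   InCopy in_copy({0, 0, 0, 0}, {0, 0, 0, 0});
//   in_copy.Run(p_src, p_dst); // stages data through the register "clipboard"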

@@ -1,5 +1,5 @@
#pragma once
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"

template <index_t BlockSize,
          class Float,
@@ -12,14 +12,16 @@ template <index_t BlockSize,
          class MapThreadCluster2SrcCluster,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite>
struct BlockwiseNdTensorCopyReorder_v3
struct BlockwiseTensorSliceReorderCopy_v3
{
    static constexpr index_t nDim = SrcLengths::GetSize();

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ BlockwiseNdTensorCopyReorder_v3()
    __device__
    BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
                                       Array<index_t, nDim> dst_block_data_multi_id_begin)
    {
        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};
@@ -43,8 +45,9 @@ struct BlockwiseNdTensorCopyReorder_v3
        static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");

        // sanity check: nDim
        static_assert(SrcDesc::GetDimension() == nDim && DstDesc::GetDimension() == nDim &&
                          SrcLengths::GetSize() == nDim && SrcSubLengths::GetSize() == nDim &&
        static_assert(SrcDesc::GetNumOfDimension() == nDim &&
                          DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
                          SrcSubLengths::GetSize() == nDim &&
                          SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
                          MapThreadCluster2SrcCluster::GetSize() == nDim,
                      "wrong! nDim is not consistent\n");
@@ -112,17 +115,17 @@ struct BlockwiseNdTensorCopyReorder_v3
        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto I    = decltype(IDim){};
            constexpr index_t i = I.Get();
            // compiler: will it really compute index here, or be associated with Get1dIndex and
            // compiler: will it really compute index here, or be merged with Get1dIndex and
            // optimized away???
            src_data_multi_id[i] *= src_sub_lengths.Get(I);
        });

        // compiler: will it really compute index here, or be associated with Get1dIndex and
        // compiler: will it really compute index here, or be merged with Get1dIndex and
        // optimized away???
        const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);

        mSrcMyThreadOffset = src_desc.Get1dIndex(src_data_multi_id);
        mDstMyThreadOffset = dst_desc.Get1dIndex(dst_data_multi_id);
        mSrcMyThreadOffset = src_desc.Get1dIndex(src_data_multi_id + src_block_data_multi_id_begin);
        mDstMyThreadOffset = dst_desc.Get1dIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
    }

    __device__ static constexpr index_t GetRegisterClipboardSize()
@@ -176,12 +179,12 @@ struct BlockwiseNdTensorCopyReorder_v3
            constexpr index_t clipboard_offset =
                thread_tensor_desc.Get1dIndex(clipboard_data_multi_id);

            threadwise_nd_tensor_copy(SrcDesc{},
                                      p_src + src_offset + mSrcMyThreadOffset,
                                      thread_tensor_desc,
                                      p_clipboard + clipboard_offset,
                                      thread_sub_tensor_lengths,
                                      Number<SrcDataPerRead>{});
            threadwise_tensor_slice_copy(SrcDesc{},
                                         p_src + src_offset + mSrcMyThreadOffset,
                                         thread_tensor_desc,
                                         p_clipboard + clipboard_offset,
                                         thread_sub_tensor_lengths,
                                         Number<SrcDataPerRead>{});
        });
    }

@@ -222,22 +225,22 @@ struct BlockwiseNdTensorCopyReorder_v3

            // write in the order of dst
#if 1
            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(thread_tensor_desc,
                                                               p_clipboard + clipboard_offset,
                                                               DstDesc{},
                                                               p_dst + dst_offset +
                                                                   mDstMyThreadOffset,
                                                               thread_sub_tensor_lengths,
                                                               MapDst2Src{});
            threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mDstMyThreadOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{});
#else
            threadwise_nd_tensor_copy_reorder_given_dst2src_v3(thread_tensor_desc,
                                                               p_clipboard + clipboard_offset,
                                                               DstDesc{},
                                                               p_dst + dst_offset +
                                                                   mDstMyThreadOffset,
                                                               thread_sub_tensor_lengths,
                                                               MapDst2Src{},
                                                               Number<DstDataPerWrite>{});
            threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
                                                                  p_clipboard + clipboard_offset,
                                                                  DstDesc{},
                                                                  p_dst + dst_offset +
                                                                      mDstMyThreadOffset,
                                                                  thread_sub_tensor_lengths,
                                                                  MapDst2Src{},
                                                                  Number<DstDataPerWrite>{});
#endif
        });
    }
@@ -14,8 +14,8 @@ __host__ __device__ constexpr auto get_convolution_output_default_4d_tensor_desc
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

@@ -45,8 +45,8 @@ __host__ __device__ constexpr auto get_convolution_with_padding_output_default_4
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    static_assert(in_desc.GetDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
    static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
    static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
                  "input & weight dimension not consistent");

@@ -1,76 +1,6 @@
#pragma once
#include "constant_integral.hip.hpp"

template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
        static_assert(Increment <= Remaining, "will go out-of-range");

        f(Number<Iter>{});
        static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
    }
};

template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
    template <class F>
    __host__ __device__ void operator()(F) const
    {
        // no work left, just return
        return;
    }
};

template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        static_assert(NBegin < NEnd, "Wrong! we should have NBegin < NEnd");
        static_assert((NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
        static_for_impl<NBegin, NEnd - NBegin, Increment>{}(f);
    }
};

template <index_t NLoop>
struct static_const_reduce_n
{
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce r) const
    {
        static_assert(NLoop > 1, "out-of-range");

        constexpr auto a = f(Number<NLoop - 1>{});
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // TODO: cannot use constexpr here, weird
        return r(a, b);
    }
};

template <>
struct static_const_reduce_n<1>
{
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce) const
    {
        return f(Number<0>{});
    }
};

#if 0
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
{
    return [=](auto xs_array){ f(xs...); };
}
#endif

struct forwarder
{
    template <typename T>
@@ -132,3 +62,76 @@ struct static_if<false>
        return Type{};
    }
};
template <index_t Iter, index_t Remaining, index_t Increment>
struct static_for_impl
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        static_assert(Remaining % Increment == 0, "wrong! Remaining % Increment != 0");
        static_assert(Increment <= Remaining, "will go out-of-range");

        f(Number<Iter>{});
        static_for_impl<Iter + Increment, Remaining - Increment, Increment>{}(f);
    }
};

template <index_t Iter, index_t Increment>
struct static_for_impl<Iter, 0, Increment>
{
    template <class F>
    __host__ __device__ void operator()(F) const
    {
        // no work left, just return
        return;
    }
};

// F signature: F(Number<I>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    template <class F>
    __host__ __device__ void operator()(F f) const
    {
        static_assert((NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");

        static_if<(NBegin < NEnd)>{}([&](auto forwarder) {
            static_for_impl<NBegin, NEnd - NBegin, forwarder(Increment)>{}(f);
        });
    }
};
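
// Illustrative usage (a sketch): iterate I = 0, 2, 4 at compile time; the
// functor receives Number<I>, so the index stays a compile-time constant.
//
//   static_for<0, 6, 2>{}([&](auto I) {
//       constexpr index_t i = decltype(I){}.Get();
//       // ... use i as a constant expression ...
//   });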

template <index_t NLoop>
struct static_const_reduce_n
{
    // signature of F: F(Number<I>)
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce r) const
    {
        static_assert(NLoop > 1, "out-of-range");

        constexpr auto a = f(Number<NLoop - 1>{});
        auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // TODO: cannot use constexpr here, weird
        return r(a, b);
    }
};

template <>
struct static_const_reduce_n<1>
{
    template <class F, class Reduce>
    __host__ __device__ constexpr auto operator()(F f, Reduce) const
    {
        return f(Number<0>{});
    }
};
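
// Illustrative usage (a sketch, following the f_calculate_num_of_lost_dim
// pattern in ConstantMergedTensorDescriptor.hip.hpp): reduce f(0), ..., f(3)
// with mod_conv::plus, i.e. a compile-time sum over four terms.
//
//   constexpr index_t sum =
//       static_const_reduce_n<4>{}(f, mod_conv::plus<index_t>{}); // f(3) + f(2) + f(1) + f(0)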

#if 0
template<class F>
__host__ __device__ constexpr auto unpacker(F f)
{
    return [=](auto xs_array){ f(xs...); };
}
#endif

@@ -3,7 +3,7 @@
#include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"

template <index_t GridSize,
@@ -229,7 +229,7 @@ struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
    }

    // copy output tensor from register to global mem
    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_nkhw_thread_desc,
        p_out_thread,
        out_nkhw_global_desc,

@@ -4,7 +4,7 @@
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -325,7 +325,7 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -375,7 +375,7 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -5,7 +5,7 @@
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_3d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -358,7 +358,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -408,7 +408,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -4,8 +4,8 @@
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_3d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "blockwise_tensor_slice_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -127,18 +127,18 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
    // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
    constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

    const auto blockwise_in_copy_reorder =
        BlockwiseNdTensorCopyReorder_v3<BlockSize,
                                        Float,
                                        decltype(in_n_c_h_w_global_desc),
                                        decltype(in_c_h_w_n_block_desc),
                                        Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
                                        InBlockReorderSrcSubLengths_NCHW,
                                        InBlockReorderSrcClusterLengths_NCHW,
                                        decltype(map_chwn2nchw),
                                        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                        InBlockReorderDataPerRead_W,
                                        InBlockReorderDataPerWrite_N>{};
    const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
        BlockSize,
        Float,
        decltype(in_n_c_h_w_global_desc),
        decltype(in_c_h_w_n_block_desc),
        Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
        InBlockReorderSrcSubLengths_NCHW,
        InBlockReorderSrcClusterLengths_NCHW,
        decltype(map_chwn2nchw),
        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
        InBlockReorderDataPerRead_W,
        InBlockReorderDataPerWrite_N>{};

    // blockwise wei copy
    // format is [CPerBlock, X * KPerBlock]

@@ -4,7 +4,7 @@
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -347,7 +347,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -397,7 +397,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -4,7 +4,7 @@
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -408,7 +408,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -458,7 +458,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -3,8 +3,8 @@
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "blockwise_tensor_slice_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -131,18 +131,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
    // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
    constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

    const auto blockwise_in_copy_reorder =
        BlockwiseNdTensorCopyReorder_v3<BlockSize,
                                        Float,
                                        decltype(in_n_c_h_w_global_desc),
                                        decltype(in_c_h_w_n_block_desc),
                                        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
                                        InBlockReorderSrcSubLengths_NCHW,
                                        InBlockReorderSrcClusterLengths_NCHW,
                                        decltype(map_chwn2nchw),
                                        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                        InBlockReorderDataPerRead_W,
                                        InBlockReorderDataPerWrite_N>{};
    const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
        BlockSize,
        Float,
        decltype(in_n_c_h_w_global_desc),
        decltype(in_c_h_w_n_block_desc),
        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
        InBlockReorderSrcSubLengths_NCHW,
        InBlockReorderSrcClusterLengths_NCHW,
        decltype(map_chwn2nchw),
        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
        InBlockReorderDataPerRead_W,
        InBlockReorderDataPerWrite_N>{};

    // blockwise wei copy
    // format is [CPerBlock, KPerBlock]
@@ -407,7 +407,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -457,7 +457,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -3,8 +3,8 @@
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "blockwise_tensor_slice_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -131,18 +131,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
    // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
    constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

    const auto blockwise_in_copy_reorder =
        BlockwiseNdTensorCopyReorder_v3<BlockSize,
                                        Float,
                                        decltype(in_n_c_h_w_global_desc),
                                        decltype(in_c_h_w_n_block_desc),
                                        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
                                        InBlockReorderSrcSubLengths_NCHW,
                                        InBlockReorderSrcClusterLengths_NCHW,
                                        decltype(map_chwn2nchw),
                                        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                        InBlockReorderDataPerRead_W,
                                        InBlockReorderDataPerWrite_N>{};
    const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
        BlockSize,
        Float,
        decltype(in_n_c_h_w_global_desc),
        decltype(in_c_h_w_n_block_desc),
        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
        InBlockReorderSrcSubLengths_NCHW,
        InBlockReorderSrcClusterLengths_NCHW,
        decltype(map_chwn2nchw),
        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
        InBlockReorderDataPerRead_W,
        InBlockReorderDataPerWrite_N>{};

    // blockwise wei copy
    // format is [CPerBlock, KPerBlock]
@@ -409,7 +409,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw

    constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};

    threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
    threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -458,7 +458,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw

    constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};

    threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
    threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -3,8 +3,8 @@
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "blockwise_tensor_slice_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -130,18 +130,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
    // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
    constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

    const auto blockwise_in_copy_reorder =
        BlockwiseNdTensorCopyReorder_v3<BlockSize,
                                        Float,
                                        decltype(in_n_c_h_w_global_desc),
                                        decltype(in_c_h_w_n_block_desc),
                                        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
                                        InBlockReorderSrcSubLengths_NCHW,
                                        InBlockReorderSrcClusterLengths_NCHW,
                                        decltype(map_chwn2nchw),
                                        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                        InBlockReorderDataPerRead_W,
                                        InBlockReorderDataPerWrite_N>{};
    const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
        BlockSize,
        Float,
        decltype(in_n_c_h_w_global_desc),
        decltype(in_c_h_w_n_block_desc),
        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
        InBlockReorderSrcSubLengths_NCHW,
        InBlockReorderSrcClusterLengths_NCHW,
        decltype(map_chwn2nchw),
        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
        InBlockReorderDataPerRead_W,
        InBlockReorderDataPerWrite_N>{};

    // blockwise wei copy
    // format is [CPerBlock, KPerBlock]
@@ -390,7 +390,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -440,7 +440,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
    }
#endif

    threadwise_nd_tensor_copy(
    threadwise_tensor_slice_copy(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -3,8 +3,8 @@
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "blockwise_tensor_slice_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

@@ -73,7 +73,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
    constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
    constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

    // divide block work: [K, Ho, Wo, N]
    // divide block work: [N, K, Ho, Wo]
    static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                      Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                  "wrong! cannot evenly divide work for workgroup ");
@@ -128,18 +128,18 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
    // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
    constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

    const auto blockwise_in_copy_reorder =
        BlockwiseNdTensorCopyReorder_v3<BlockSize,
                                        Float,
                                        decltype(in_n_c_h_w_global_desc),
                                        decltype(in_c_h_w_n_block_desc),
                                        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
                                        InBlockReorderSrcSubLengths_NCHW,
                                        InBlockReorderSrcClusterLengths_NCHW,
                                        decltype(map_chwn2nchw),
                                        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                        InBlockReorderDataPerRead_W,
                                        InBlockReorderDataPerWrite_N>{};
    const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
        BlockSize,
        Float,
        decltype(in_n_c_h_w_global_desc),
        decltype(in_c_h_w_n_block_desc),
        Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
        InBlockReorderSrcSubLengths_NCHW,
        InBlockReorderSrcClusterLengths_NCHW,
        decltype(map_chwn2nchw),
        InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
        InBlockReorderDataPerRead_W,
        InBlockReorderDataPerWrite_N>{};

    // blockwise wei copy
    // format is [CPerBlock, KPerBlock]
@@ -390,7 +390,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw

    constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};

    threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
    threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,
@@ -439,7 +439,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw

    constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};

    threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
    threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
        out_10d_thread_desc,
        p_out_thread,
        out_10d_global_desc,

@@ -5,7 +5,7 @@
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "threadwise_2d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_tensor_slice_op.hip.hpp"
#include "blockwise_gemm.hip.hpp"

// define B = flatten(N, Hi, Wi)

@@ -0,0 +1,309 @@
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_gemm.hip.hpp"

// define B = merge(N, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t N1,
          index_t N2,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB>
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // this is a mess
        // TODO: more elegant way of specifying (or calculating) performance variables
        static_assert(N2 == GemmNPerThreadSubC, "wrong!");
        static_assert(KPerBlock ==
                          N1 * GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster,
                      "wrong!");
        static_assert(
            KPerBlock % (N1 * GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};
        constexpr auto I5 = Number<5>{};
        constexpr auto I6 = Number<6>{};
        constexpr auto I7 = Number<7>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_h_w_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_h_w_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");

        constexpr index_t N0 = N / (N1 * N2);

        constexpr index_t B = N0 * Ho * Wo;

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % CPerBlock == 0,
                      "wrong! cannot divide work evenly among blocks");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_ConstantTensorDescriptor(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
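
        // Numeric example (hypothetical sizes, as an illustration): with K = 128,
        // KPerBlock = 32, B = 2048 and BPerBlock = 128, block_work_desc is [4, 16];
        // block id 13 decodes to block_work_multi_id = {0, 13}, i.e. k offset
        // 0 * 32 = 0 and b offset 13 * 128 = 1664 on the global tensor.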

        // input tensor
        // memory layout descriptor in device memory [N0, N1, N2, C, H, W]
        constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
            in_n_c_h_w_global_desc.Fold(I0, Sequence<N1, N2>{});

        // merged tensor descriptor in device memory [N1, N2, C, B], src of blockwise copy
        constexpr auto in_n1_n2_c_b_global_merged_desc =
            in_n0_n1_n2_c_h_w_global_mem_desc.ReorderGivenNew2Old(Sequence<1, 2, 3, 0, 4, 5>{})
                .Slice(I4, Number<Ho>{})
                .Slice(I5, Number<Wo>{})
                .Merge(I3, I5);
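
        // Shape walk-through (hypothetical sizes, as an illustration): with N0 = 2,
        // N1 = 2, N2 = 4, C = 8 and Ho = Wo = 16, the reorder above yields
        // [N1, N2, C, N0, Ho, Wo], the slices keep Ho and Wo, and Merge(I3, I5)
        // collapses (N0, Ho, Wo) into the single merged dimension
        // B = 2 * 16 * 16 = 512, giving the 4d merged view [N1, N2, C, B].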
|
||||
|
||||
// memory layout descriptor in LDS [C, N1, B, N2]
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// tensor descriptor in LDS [N1, N2, C, B], dst of blockwise copy
|
||||
constexpr auto in_n1_n2_c_b_block_desc =
|
||||
in_c_n1_b_n2_block_mem_desc.ReorderGivenNew2Old(Sequence<1, 3, 0, 2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_in_copy = BlockwiseTensorSliceCopy_generic_v1<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_n1_n2_c_b_global_merged_desc),
|
||||
decltype(in_n1_n2_c_b_block_desc),
|
||||
Sequence<N1, N2, CPerBlock, BPerBlock>,
|
||||
InBlockCopySubLengths_N1_N2_C_B,
|
||||
InBlockCopyClusterLengths_N1_N2_C_B,
|
||||
Sequence<2, 0, 1, 3>, // thread_arrange_order [C, N1, N2, B]
|
||||
Sequence<0, 1, 2, 3>, // src_access_order [N1, N2, C, B]
|
||||
Sequence<2, 0, 3, 1>, // dst_access_order [C, N1, B, N2]
|
||||
>({0, 0, 0, b_block_data_on_global}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(Sequence<0, 3>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slicing a tensor
|
||||
// this copy operator already have tensor offset built-in
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>({0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[CPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[CPerBlocl, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<N1 * BPerBlock * N2>{},
|
||||
Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster),
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO:: more elegent way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(out_k0_k1_k2_n1_n0_h_w_n2_thread_desc, p_out_thread);
|
||||
|
||||
// do work
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// calculate origin of block input and weight tensor on global memory
|
||||
const Float* p_in_block_on_global =
|
||||
p_in_global + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x);
|
||||
|
||||
const Float* p_wei_block_on_global =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0);
|
||||
|
||||
for(index_t
|
||||
c_block_data_on_global = 0;
|
||||
c_block_data_on_global < C;
|
||||
c_block_data_on_global += CPerBlock,
|
||||
p_in_block_ont_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy.run(p_in_block_on_global, p_in_block);
|
||||
blockwise_wei_copy.run(p_wei_block_on_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_gemm.run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / (K1 * K2);
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output tensor (also, memory layout) descriptor in register, src of threadwise
|
||||
// copy
|
||||
constexpr auto out_k0_k1_k2_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output memory layout descriptor in device memory
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global.Fold(I1, Sequence<K1, K2>{}).Fold(I0, Sequence<N1, N2>{});
|
||||
|
||||
// output merged tensor descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_k2_n1_b_n2_global_merged_desc =
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc
|
||||
.ReorderGivenNew2Old(Sequence<3, 4, 5, 1, 0, 6, 7, 2>{})
|
||||
.Merge(I4, I6);
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
// origin of thread tensor on global
|
||||
const index_t k_thread_data_on_global k_block_data_on_global +
|
||||
c_thread_mtx_on_block.row;
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc =
|
||||
out_k0_k1_k2_n1_b_n2_global_merged_desc.Unfold(I1, I2);

    // origin of thread tensor in global memory
    Float* const p_out_thread_on_global =
        p_out_global +
        out_k_n1_b_n2_global_merged_desc.Get1dIndex(
            k_thread_data_on_global, 0, 0, 0); // dst origin on merged global tensor

    // copy
    threadwise_tensor_slice_copy_generic(
        out_k0_k1_k2_n1_b_n2_thread_mem_desc,     // src thread tensor (in register) descriptor
        p_out_thread,                             // origin of src
        {0, 0, 0, 0, 0, 0},                       // starting point of slice, w.r.t. origin of src
        out_k0_k1_k2_n1_b_n2_global_merged_desc,  // dst global merged tensor (in device mem) descriptor
        p_out_thread_on_global,                   // origin of dst
        {0, 0, 0, 0, b_thread_data_on_global, 0}, // starting point of slice, w.r.t. origin of dst
        out_k0_k1_k2_n1_b_n2_thread_mem_desc.GetLengths(), // slice lengths
        Sequence<2, 3, 4, 0, 5, 1>{}                       // order of dimension access
        );
}
}
};
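
For reference, here is the load / barrier / compute / barrier pipeline used in the kernel above, reduced to a minimal self-contained HIP sketch (a plain tiled GEMM; the names, the tile size and the row-major layouts are illustrative assumptions, not part of this commit):

#include "hip/hip_runtime.h"

#define TILE 16

// C[M,N] += A[M,K] * B[K,N], all row-major; one thread computes one C element
__global__ void tiled_gemm_pipeline_sketch(
    const float* A, const float* B, float* C, int M, int N, int K)
{
    __shared__ float a_blk[TILE][TILE];
    __shared__ float b_blk[TILE][TILE];

    const int row = blockIdx.y * TILE + threadIdx.y;
    const int col = blockIdx.x * TILE + threadIdx.x;

    float acc = 0.f;

    for(int k0 = 0; k0 < K; k0 += TILE)
    {
        // cooperative copy: each thread stages one element of each tile into LDS
        a_blk[threadIdx.y][threadIdx.x] =
            (row < M && k0 + threadIdx.x < K) ? A[row * K + k0 + threadIdx.x] : 0.f;
        b_blk[threadIdx.y][threadIdx.x] =
            (col < N && k0 + threadIdx.y < K) ? B[(k0 + threadIdx.y) * N + col] : 0.f;

        __syncthreads(); // tiles fully written before anyone reads them

        for(int k = 0; k < TILE; ++k)
            acc += a_blk[threadIdx.y][k] * b_blk[k][threadIdx.x];

        __syncthreads(); // tiles fully read before the next iteration overwrites them
    }

    if(row < M && col < N)
        C[row * N + col] += acc;
}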

@@ -85,7 +85,7 @@ struct TensorDescriptor
     {
     }
 
-    std::size_t GetDimension() const;
+    std::size_t GetNumOfDimension() const;
     std::size_t GetElementSize() const;
     std::size_t GetElementSpace() const;

@@ -95,7 +95,7 @@ struct TensorDescriptor
     template <class... Is>
     std::size_t Get1dIndex(Is... is) const
     {
-        assert(sizeof...(Is) == this->GetDimension());
+        assert(sizeof...(Is) == this->GetNumOfDimension());
         std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
         return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
     }
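
Get1dIndex is the inner product of the indices with the strides. A worked example with assumed values: for lengths {2, 3, 4} the packed strides are {12, 4, 1}, so Get1dIndex(1, 2, 3) = 1*12 + 2*4 + 3*1 = 23.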

@@ -206,7 +206,7 @@ struct Tensor
     template <class G>
     void GenerateTensorValue(G g, std::size_t num_thread = 1)
     {
-        switch(mDesc.GetDimension())
+        switch(mDesc.GetNumOfDimension())
         {
         case 1:
         {

@@ -88,7 +88,7 @@ threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
         SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
 }
 
-#if 0 // replaced threadwise_nd_tensor_copy
+#if 0 // replaced by threadwise_tensor_slice_copy
 template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
 __device__ void threadwise_2d_tensor_copy(
     SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)

@@ -1,6 +1,6 @@
 #pragma once
 #include "ConstantTensorDescriptor.hip.hpp"
-#include "threadwise_nd_tensor_op.hip.hpp"
+#include "threadwise_tensor_slice_op.hip.hpp"
 
 // optimized for the scenario where p_in, p_wei and p_out are in registers
 template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>

@@ -85,11 +85,11 @@ __device__ void threadwise_direct_convolution_2(InDesc,
     TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];
 
     // copy input tensor into register
-    threadwise_nd_tensor_copy(
+    threadwise_tensor_slice_copy(
         in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{});
 
     // copy weight tensor into register
-    threadwise_nd_tensor_copy(
+    threadwise_tensor_slice_copy(
         wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{});
 
     // do convolution

@@ -1,4 +1,19 @@
 #pragma once
 #include "common.hip.hpp"
 #include "ConstantMatrixDescriptor.hip.hpp"
 
+template <class Float, class Matrix>
+__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
+{
+    for(index_t i = 0; i < Matrix::NRow(); ++i)
+    {
+        for(index_t j = 0; j < Matrix::NCol(); ++j)
+        {
+            const index_t id = Matrix::Get1dIndex(i, j);
+            p_thread[id] = 0;
+        }
+    }
+}
+
 template <class Float,
           class SrcMatrix,
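
A usage sketch for the threadwise_matrix_set_zero helper added above (illustrative; it assumes a make_ConstantMatrixDescriptor maker and a GetElementSpace() member analogous to the tensor-descriptor helpers, neither of which is verified against this commit):

    constexpr auto c_desc = make_ConstantMatrixDescriptor(Number<4>{}, Number<8>{});
    Float p_c_thread[c_desc.GetElementSpace()];
    threadwise_matrix_set_zero(c_desc, p_c_thread);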

@@ -64,9 +79,9 @@ __device__ void threadwise_gemm(MatrixA,
 
     for(index_t k = 0; k < K; ++k)
     {
-        for(index_t i = 0; i < M; i++)
+        for(index_t i = 0; i < M; ++i)
         {
-            for(index_t j = 0; j < N; j++)
+            for(index_t j = 0; j < N; ++j)
             {
                 const index_t aindex = a_mtx.Get1dIndex(k, i); // A is transposed
                 const index_t bindex = b_mtx.Get1dIndex(k, j);
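
(As the indexing shows, A is addressed as A(k, i) and B as B(k, j), so these loops accumulate C(i, j) += A(k, i) * B(k, j), i.e. a GEMM on a transposed A: C = A^T * B.)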

@@ -3,18 +3,18 @@
 
 // need to assume src and dst are aligned
 template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
-__device__ void threadwise_nd_tensor_copy(SrcDesc,
-                                          const Float* __restrict__ p_src,
-                                          DstDesc,
-                                          Float* __restrict__ p_dst,
-                                          SrcOpLengths,
-                                          Number<DataPerRead>)
+__device__ void threadwise_tensor_slice_copy(SrcDesc,
+                                             const Float* __restrict__ p_src,
+                                             DstDesc,
+                                             Float* __restrict__ p_dst,
+                                             SrcOpLengths,
+                                             Number<DataPerRead>)
 {
     using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
 
     constexpr index_t nDim = SrcOpLengths::GetSize();
 
-    static_assert(SrcDesc{}.GetDimension() == nDim && DstDesc{}.GetDimension() == nDim,
+    static_assert(SrcDesc{}.GetNumOfDimension() == nDim && DstDesc{}.GetNumOfDimension() == nDim,
                   "wrong! dimension not consistent");
 
     constexpr auto src_desc = SrcDesc{};

@@ -63,7 +63,7 @@ __device__ void threadwise_nd_tensor_copy(SrcDesc,
     });
 }
 
-// write in order of src
+// access in order of src
 template <class SrcData,
           class DstData,
           class SrcDesc,

@@ -71,12 +71,12 @@ template <class SrcData,
           class SrcOpLengths,
           class MapDst2Src>
 __device__ void
-threadwise_nd_tensor_copy_reorder_given_dst2src_v1(SrcDesc,
-                                                   const SrcData* __restrict__ p_src,
-                                                   DstDesc,
-                                                   DstData* __restrict__ p_dst,
-                                                   SrcOpLengths,
-                                                   MapDst2Src)
+threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc,
+                                                      const SrcData* __restrict__ p_src,
+                                                      DstDesc,
+                                                      DstData* __restrict__ p_dst,
+                                                      SrcOpLengths,
+                                                      MapDst2Src)
 {
     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};

@@ -92,7 +92,7 @@ threadwise_nd_tensor_copy_reorder_given_dst2src_v1(SrcDesc,
     });
 }
 
-// write in order of dst
+// access in order of dst
 template <class SrcData,
           class DstData,
           class SrcDesc,

@@ -100,12 +100,12 @@ template <class SrcData,
           class SrcOpLengths,
           class MapDst2Src>
 __device__ void
-threadwise_nd_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
-                                                   const SrcData* __restrict__ p_src,
-                                                   DstDesc,
-                                                   DstData* __restrict__ p_dst,
-                                                   SrcOpLengths,
-                                                   MapDst2Src)
+threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc,
+                                                      const SrcData* __restrict__ p_src,
+                                                      DstDesc,
+                                                      DstData* __restrict__ p_dst,
+                                                      SrcOpLengths,
+                                                      MapDst2Src)
 {
     constexpr auto src_desc = SrcDesc{};
     constexpr auto dst_desc = DstDesc{};

@@ -123,20 +123,22 @@ threadwise_nd_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
     });
 }
 
-// write in order of dst
+// access in order of dst
 // manually pack data into vector before write
 template <class Float,
           class SrcDesc,
           class DstDesc,
           class SrcOpLengths,
           class MapDst2Src,
           index_t DstDataPerWrite>
-__device__ void threadwise_nd_tensor_copy_reorder_given_dst2src_v3(SrcDesc,
-                                                                   const Float* __restrict__ p_src,
-                                                                   DstDesc,
-                                                                   Float* __restrict__ p_dst,
-                                                                   SrcOpLengths,
-                                                                   MapDst2Src,
-                                                                   Number<DstDataPerWrite>)
+__device__ void
+threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
+                                                      const Float* __restrict__ p_src,
+                                                      DstDesc,
+                                                      Float* __restrict__ p_dst,
+                                                      SrcOpLengths,
+                                                      MapDst2Src,
+                                                      Number<DstDataPerWrite>)
 {
     using vector_t = typename vector_type<Float, DstDataPerWrite>::MemoryType;
 
@@ -190,3 +192,17 @@ __device__ void threadwise_nd_tensor_copy_reorder_given_dst2src_v3(SrcDesc,
     });
     });
 }
+
+template <class Float, class SrcDesc, class DstDesc, class SliceLengths, class DimAccessOrder>
+__device__ void
+threadwise_tensor_slice_copy_generic(SrcDesc,
+                                     const Float* __restrict__ p_src,
+                                     Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_offset,
+                                     DstDesc,
+                                     Float* __restrict__ p_dst,
+                                     Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_offset,
+                                     SliceLengths,
+                                     DimAccessOrder)
+{
+    // not implemented
+}
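
threadwise_tensor_slice_copy_generic is left unimplemented in this commit. One possible shape for its body, sketched with runtime loops instead of the compile-time descriptors and static_for used elsewhere in this code base (illustrative only; every name below is an assumption):

// Illustrative only: visit every point of an NDim-d slice, with the
// fastest-varying dimension given by dim_access_order, copying one scalar
// per step. index_t is the integer type defined in common.hip.hpp.
template <class Float, int NDim>
__device__ void slice_copy_odometer_sketch(const Float* p_src,
                                           const index_t* src_strides,
                                           Float* p_dst,
                                           const index_t* dst_strides,
                                           const index_t* slice_lengths,
                                           const index_t* dim_access_order)
{
    index_t idx[NDim] = {0};

    for(;;)
    {
        index_t src_off = 0, dst_off = 0;
        for(int d = 0; d < NDim; ++d)
        {
            src_off += idx[d] * src_strides[d];
            dst_off += idx[d] * dst_strides[d];
        }
        p_dst[dst_off] = p_src[src_off];

        // advance like an odometer: dim_access_order[NDim - 1] varies fastest
        int d = NDim - 1;
        for(; d >= 0; --d)
        {
            const index_t dim = dim_access_order[d];
            if(++idx[dim] < slice_lengths[dim])
                break;
            idx[dim] = 0;
        }
        if(d < 0)
            return; // every dimension wrapped around: slice fully copied
    }
}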

@@ -25,7 +25,7 @@ void TensorDescriptor::CalculateStrides()
         mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>());
 }
 
-std::size_t TensorDescriptor::GetDimension() const { return mLens.size(); }
+std::size_t TensorDescriptor::GetNumOfDimension() const { return mLens.size(); }
 
 std::size_t TensorDescriptor::GetElementSize() const
 {