mirror of https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
implicit gemm v1r3 nchw_cyxk_nkhw
@@ -24,7 +24,7 @@ struct Array
     {
         Array<TData, NSize + 1> new_array;

-        static_for<0, NSize, 1>{}([=](auto I) {
+        static_for<0, NSize, 1>{}([&](auto I) {
             constexpr index_t i = I.Get();
             new_array[i] = mData[i];
         });

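The [=] to [&] change above is the whole fix: capturing by value hands the static_for body a private copy of new_array, so its writes never reach the array being built. A minimal standalone sketch (plain C++, not repo code) of the same failure mode:

// Standalone sketch (not repo code) of why the lambda capture matters:
// a [=] capture mutates a private copy, a [&] capture mutates the original.
#include <array>
#include <cstdio>

int main()
{
    std::array<int, 3> src{1, 2, 3};
    std::array<int, 4> dst{};

    // capture by value: writes land in the lambda's copy and are discarded
    auto by_value = [=]() mutable {
        for(int i = 0; i < 3; ++i)
            dst[i] = src[i];
    };
    by_value();
    std::printf("by value: %d %d %d\n", dst[0], dst[1], dst[2]); // 0 0 0

    // capture by reference: writes reach the caller's array
    auto by_ref = [&]() {
        for(int i = 0; i < 3; ++i)
            dst[i] = src[i];
    };
    by_ref();
    std::printf("by ref:   %d %d %d\n", dst[0], dst[1], dst[2]); // 1 2 3
}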
@@ -137,11 +137,16 @@ struct ConstantTensorDescriptor
     }

     template <index_t... Is>
-    __host__ __device__ static constexpr index_t Get1dIndex(Sequence<Is...> multi_id)
+    __host__ __device__ static constexpr index_t Get1dIndex(Sequence<Is...> /*multi_id*/)
     {
         static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

-        return Get1dIndex(Is...);
+        constexpr auto multi_id = Sequence<Is...>{};
+
+        constexpr auto seq_tmp =
+            transform_sequences(mod_conv::multiplies<index_t>{}, multi_id, GetStrides());
+
+        return accumulate_on_sequence(seq_tmp, mod_conv::plus<index_t>{}, Number<0>{});
     }

     __host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)

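The rewritten Get1dIndex computes the usual strided offset: the dot product of the multi-index with the descriptor's strides. A minimal constexpr sketch in plain C++ (std types standing in for Sequence and the mod_conv functors; not repo code):

// Plain-C++ sketch of the offset computation: offset = sum_i idx[i] * stride[i]
#include <cstddef>

template <std::size_t N>
constexpr std::size_t get_1d_index(const std::size_t (&idx)[N],
                                   const std::size_t (&stride)[N])
{
    std::size_t offset = 0;
    for(std::size_t i = 0; i < N; ++i) // C++14 constexpr loop
        offset += idx[i] * stride[i];
    return offset;
}

// packed NCHW lengths {2,3,4,5} give strides {60,20,5,1};
// multi-index {1,2,3,4} maps to 1*60 + 2*20 + 3*5 + 4*1 = 119
static_assert(get_1d_index<4>({1, 2, 3, 4}, {60, 20, 5, 1}) == 119, "dot(idx, stride)");

int main() {}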
@@ -246,7 +246,8 @@ struct accumulate_on_sequence_f
 };

 template <class Seq, class Reduce, index_t I>
-__host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number<I>)
+__host__ __device__ constexpr index_t
+accumulate_on_sequence(Seq, Reduce, Number<I> /*initial_value*/)
 {
     constexpr index_t a =
         static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});

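accumulate_on_sequence is a compile-time left fold of the Sequence's values under Reduce, seeded with the initial value. A plain-C++17 sketch of the same reduction (std::plus standing in for mod_conv::plus; not repo code):

// Standalone sketch: fold a compile-time integer pack.
#include <functional>

template <int... Is, class Reduce>
constexpr int accumulate_on_sequence_demo(Reduce reduce, int init)
{
    int acc = init;
    ((acc = reduce(acc, Is)), ...); // C++17 fold over the pack
    return acc;
}

// 60 + 40 + 15 + 4, seeded with 0 -- e.g. summing the idx*stride terms above
static_assert(accumulate_on_sequence_demo<60, 40, 15, 4>(std::plus<int>{}, 0) == 119, "");

int main() {}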
@@ -471,7 +471,6 @@ struct Blockwise2dTensorCopy3
                          DstDesc{}.GetStride(I0) % DataPerRead == 0,
                      "src and dst stride should be multiple of DataPerRead to keep alignment");

        constexpr index_t L0 = CopyLengths{}.Get(I0);
        constexpr index_t L1 = CopyLengths{}.Get(I1);

        constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;

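The stride/DataPerRead assertion this hunk sits in exists because vectorized accesses are only legal on aligned addresses: if every row starts at a multiple of DataPerRead elements, a float2/float4 access never straddles its natural alignment. A tiny standalone check with hypothetical numbers (not repo code):

// Sketch of the alignment invariant behind the static_assert: row r starts
// at element r * stride, and all row starts stay DataPerRead-aligned
// iff stride % DataPerRead == 0.
#include <cstdio>

int main()
{
    constexpr int DataPerRead = 4;   // floats per vector access (float4)
    constexpr int stride      = 132; // hypothetical elements between rows

    static_assert(stride % DataPerRead == 0, "rows would break vector alignment");

    for(int r = 0; r < 3; ++r)
        std::printf("row %d starts at element %d, offset %% DataPerRead = %d\n",
                    r, r * stride, (r * stride) % DataPerRead);
}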
@@ -761,339 +761,3 @@ struct Blockwise4dTensorCopyReorder1
             SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
     }
 };
-
-template <index_t BlockSize,
-          class Float,
-          class SrcDesc,
-          class DstDesc,
-          class SrcLengths,
-          class SrcSubLengths,
-          class SrcClusterLengths,
-          class MapDst2Src,
-          class MapThreadCluster2SrcCluster,
-          index_t SrcDataPerRead,
-          index_t DstDataPerWrite>
-struct Blockwise4dTensorCopyReorder3
-{
-    static constexpr index_t nDim = SrcLengths::GetSize();
-
-    index_t mSrcMyThreadOffset;
-    index_t mDstMyThreadOffset;
-
-    __device__ Blockwise4dTensorCopyReorder3()
-    {
-        constexpr auto src_desc = SrcDesc{};
-        constexpr auto dst_desc = DstDesc{};
-
-        constexpr auto src_lengths = SrcLengths{};
-
-        constexpr auto map_dst2src = MapDst2Src{};
-
-        constexpr auto src_sub_lengths = SrcSubLengths{};
-        constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);
-
-        constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};
-
-        constexpr auto src_cluster_lengths = SrcClusterLengths{};
-        constexpr auto thread_cluster_lengths =
-            src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);
-
-        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(thread_cluster_lengths);
-
-        // sanity check: data type
-        static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");
-
-        // sanity check: nDim
-        static_assert(SrcDesc::GetDimension() == nDim && DstDesc::GetDimension() == nDim &&
-                          SrcLengths::GetSize() == nDim && SrcSubLengths::GetSize() == nDim &&
-                          SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
-                          MapThreadCluster2SrcCluster::GetSize() == nDim,
-                      "wrong! nDim is not consistent\n");
-
-        // sanity check: BlockSize
-        constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();
-
-        static_assert(BlockSize >= num_active_thread,
-                      "wrong! BlockSize is not big enough for ThreadPerDims!");
-
-        // sanity check: work division
-        static_for<0, nDim, 1>{}([](auto IDim) {
-            constexpr auto I = decltype(IDim){};
-            constexpr index_t src_len = src_lengths.Get(I);
-            constexpr index_t src_sub_len = src_sub_lengths.Get(I);
-            constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
-            static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
-                          "wrong! cannot evenly divide Src tensor lengths");
-        });
-
-        // sanity check: src read
-        static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
-                      "wrong! only support SrcDataPerRead == 1, 2 or 4!\n");
-
-        static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
-                      "wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");
-
-        static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
-                      "wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");
-
-        static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
-                      "wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
-                      "keep alignment");
-
-        // sanity check: dst write
-        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
-                      "wrong! only support DstDataPerWrite == 1, 2 or 4!\n");
-
-        static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
-                      "wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");
-
-        static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
-                      "wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");
-
-        static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
-                      "wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
-                      "keep alignment");
-
-        // start dividing work
-        if(BlockSize > num_active_thread)
-        {
-            if(get_thread_local_1d_id() >= num_active_thread)
-            {
-                return;
-            }
-        }
-
-        const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());
-
-        // compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
-        // registers, or only one copy???
-        auto src_data_multi_id =
-            reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);
-
-        static_for<0, nDim, 1>{}([&](auto IDim) {
-            constexpr auto I = decltype(IDim){};
-            constexpr index_t i = I.Get();
-            // compiler: will it really compute index here, or be associated with Get1dIndex and
-            // optimized away???
-            src_data_multi_id[i] *= src_sub_lengths.Get(I);
-        });
-
-        // compiler: will it really compute index here, or be associated with Get1dIndex and
-        // optimized away???
-        const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);
-
-        mSrcMyThreadOffset = src_desc.Get1dIndex(src_data_multi_id);
-        mDstMyThreadOffset = dst_desc.Get1dIndex(dst_data_multi_id);
-
-#if 0
-        if(get_block_1d_id() == 0)
-        {
-            printf("tid %5u, "
-                   "thread_multi_id %5u %5u %5u %5u, "
-                   "src_data_multi_id %5u %5u %5u %5u, "
-                   "dst_data_multi_id %5u %5u %5u %5u, "
-                   "mSrcMyThreadOffset %u, mDstMyThreadOffset %u\n",
-                   get_thread_local_1d_id(),
-                   thread_multi_id[0],
-                   thread_multi_id[1],
-                   thread_multi_id[2],
-                   thread_multi_id[3],
-                   src_data_multi_id[0],
-                   src_data_multi_id[1],
-                   src_data_multi_id[2],
-                   src_data_multi_id[3],
-                   dst_data_multi_id[0],
-                   dst_data_multi_id[1],
-                   dst_data_multi_id[2],
-                   dst_data_multi_id[3],
-                   mSrcMyThreadOffset,
-                   mDstMyThreadOffset);
-        }
-#endif
-    }
-
-    __device__ static constexpr index_t GetRegisterClipboardSize()
-    {
-        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
-
-        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
-            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
-
-        constexpr auto cluster_per_dims =
-            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
-                                SrcLengths{},
-                                src_data_per_cluster_per_dims);
-
-        constexpr auto thread_tensor_lengths = transform_sequences(
-            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, cluster_per_dims);
-
-        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
-
-        return thread_tensor_desc.GetElementSpace();
-    }
-
-    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
-                                             Float* __restrict__ p_clipboard) const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
-        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
-
-        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
-            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
-
-        constexpr auto cluster_per_dims =
-            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
-                                SrcLengths{},
-                                src_data_per_cluster_per_dims);
-
-        constexpr auto thread_tensor_lengths = transform_sequences(
-            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, cluster_per_dims);
-
-        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
-
-        constexpr auto thread_sub_tensor_desc =
-            make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides());
-
-#if 1
-        for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0)
-        {
-            for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1)
-            {
-                for(index_t icluster_d2 = 0; icluster_d2 < cluster_per_dims.Get(I2); ++icluster_d2)
-                {
-                    for(index_t icluster_d3 = 0; icluster_d3 < cluster_per_dims.Get(I3);
-                        ++icluster_d3)
-                    {
-                        const index_t src_offset = SrcDesc{}.Get1dIndex(
-                            icluster_d0 * src_data_per_cluster_per_dims.Get(I0),
-                            icluster_d1 * src_data_per_cluster_per_dims.Get(I1),
-                            icluster_d2 * src_data_per_cluster_per_dims.Get(I2),
-                            icluster_d3 * src_data_per_cluster_per_dims.Get(I3));
-
-                        const index_t clipboard_offset = thread_tensor_desc.Get1dIndex(
-                            icluster_d0 * thread_sub_tensor_lengths.Get(I0),
-                            icluster_d1 * thread_sub_tensor_lengths.Get(I1),
-                            icluster_d2 * thread_sub_tensor_lengths.Get(I2),
-                            icluster_d3 * thread_sub_tensor_lengths.Get(I3));
-
-                        threadwise_nd_tensor_copy(SrcDesc{},
-                                                  p_src + src_offset + mSrcMyThreadOffset,
-                                                  thread_tensor_desc,
-                                                  p_clipboard + clipboard_offset,
-                                                  thread_sub_tensor_lengths,
-                                                  Number<SrcDataPerRead>{});
-                    }
-                }
-            }
-        }
-#else
-        static_ford<decltype(cluster_per_dims)>{}([=](auto cluster_ids) {
-
-        });
-#endif
-
-#if 0
-        if(get_block_1d_id() == 0)
-        {
-            printf("tid %5u, "
-                   "data: %f %f %f %f %f %f %f %f\n",
-                   get_thread_local_1d_id(),
-                   p_clipboard[0],
-                   p_clipboard[1],
-                   p_clipboard[2],
-                   p_clipboard[3],
-                   p_clipboard[4],
-                   p_clipboard[5],
-                   p_clipboard[6],
-                   p_clipboard[7]);
-        }
-#endif
-    }
-
-    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
-                                              Float* __restrict__ p_dst) const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
-        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
-
-        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
-            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
-
-        constexpr auto cluster_per_dims =
-            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
-                                SrcLengths{},
-                                src_data_per_cluster_per_dims);
-
-        constexpr auto thread_tensor_lengths = transform_sequences(
-            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, cluster_per_dims);
-
-        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
-
-        constexpr auto thread_sub_tensor_desc =
-            make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides());
-
-        for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0)
-        {
-            for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1)
-            {
-                for(index_t icluster_d2 = 0; icluster_d2 < cluster_per_dims.Get(I2); ++icluster_d2)
-                {
-                    for(index_t icluster_d3 = 0; icluster_d3 < cluster_per_dims.Get(I3);
-                        ++icluster_d3)
-                    {
-                        const index_t clipboard_offset = thread_tensor_desc.Get1dIndex(
-                            icluster_d0 * thread_sub_tensor_lengths.Get(I0),
-                            icluster_d1 * thread_sub_tensor_lengths.Get(I1),
-                            icluster_d2 * thread_sub_tensor_lengths.Get(I2),
-                            icluster_d3 * thread_sub_tensor_lengths.Get(I3));
-
-                        const auto dst_multi_id = reorder_array_given_new2old(
-                            Array<index_t, nDim>{
-                                icluster_d0 * src_data_per_cluster_per_dims.Get(I0),
-                                icluster_d1 * src_data_per_cluster_per_dims.Get(I1),
-                                icluster_d2 * src_data_per_cluster_per_dims.Get(I2),
-                                icluster_d3 * src_data_per_cluster_per_dims.Get(I3)},
-                            MapDst2Src{});
-
-                        const index_t dst_offset = DstDesc{}.Get1dIndex(dst_multi_id);
-
-#if 0
-                        if(get_block_1d_id() == 0)
-                        {
-                            printf("tid %5u, "
-                                   "clipboard_offset %5u, dst_offset %5u\n",
-                                   get_thread_local_1d_id(),
-                                   clipboard_offset,
-                                   dst_offset);
-                        }
-#endif
-
-                        threadwise_4d_tensor_copy_reorder_given_dst2src_v2(
-                            thread_tensor_desc,
-                            p_clipboard + clipboard_offset,
-                            DstDesc{},
-                            p_dst + dst_offset + mDstMyThreadOffset,
-                            thread_sub_tensor_lengths,
-                            MapDst2Src{});
-                    }
-                }
-            }
-        }
-    }
-
-    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
-    {
-        Float p_clipboard[GetRegisterClipboardSize()];
-
-        RunLoadRegisterClipboard(p_src, p_clipboard);
-        RunStoreRegisterClipboard(p_clipboard, p_dst);
-    }
-};

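Before it was deleted, the struct above divided each tensor dimension into sub_len elements per thread per pass, cluster_len cooperating threads, and ceil(src_len / (sub_len * cluster_len)) repeats; the n-dimensional replacement added later in this commit keeps the same arithmetic under the name repeat_lengths. A worked example with hypothetical numbers (not repo code):

// Standalone sketch of the per-dimension work division shared by the old
// and new copy classes.
#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int src_len     = 64; // tensor length along one dimension
    constexpr int sub_len     = 4;  // elements one thread copies per pass
    constexpr int cluster_len = 8;  // threads cooperating along this dimension

    // one pass covers sub_len * cluster_len = 32 elements, so the cluster
    // sweeps the dimension in ceil(64 / 32) = 2 passes
    constexpr int repeat_len = integer_divide_ceil(src_len, sub_len * cluster_len);
    static_assert(repeat_len == 2, "two passes over this dimension");

    // each thread therefore owns sub_len * repeat_len elements of clipboard
    std::printf("clipboard elements per thread along this dim: %d\n",
                sub_len * repeat_len);
}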
@@ -53,7 +53,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
@@ -114,8 +113,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2

    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
    {
        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

src/include/blockwise_nd_tensor_op.hip.hpp (new file, 252 lines)
@@ -0,0 +1,252 @@
#pragma once
#include "threadwise_nd_tensor_op.hip.hpp"

template <index_t BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcLengths,
          class SrcSubLengths,
          class SrcClusterLengths,
          class MapDst2Src,
          class MapThreadCluster2SrcCluster,
          index_t SrcDataPerRead,
          index_t DstDataPerWrite>
struct BlockwiseNdTensorCopyReorder_v3
{
    static constexpr index_t nDim = SrcLengths::GetSize();

    index_t mSrcMyThreadOffset;
    index_t mDstMyThreadOffset;

    __device__ BlockwiseNdTensorCopyReorder_v3()
    {
        constexpr auto src_desc = SrcDesc{};
        constexpr auto dst_desc = DstDesc{};

        constexpr auto src_lengths = SrcLengths{};

        constexpr auto map_dst2src = MapDst2Src{};

        constexpr auto src_sub_lengths = SrcSubLengths{};
        constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);

        constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};

        constexpr auto src_cluster_lengths = SrcClusterLengths{};
        constexpr auto thread_cluster_lengths =
            src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);

        constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(thread_cluster_lengths);

        // sanity check: data type
        static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");

        // sanity check: nDim
        static_assert(SrcDesc::GetDimension() == nDim && DstDesc::GetDimension() == nDim &&
                          SrcLengths::GetSize() == nDim && SrcSubLengths::GetSize() == nDim &&
                          SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
                          MapThreadCluster2SrcCluster::GetSize() == nDim,
                      "wrong! nDim is not consistent\n");

        // sanity check: BlockSize
        constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();

        static_assert(BlockSize >= num_active_thread,
                      "wrong! BlockSize is not big enough for ThreadPerDims!");

        // sanity check: work division
        static_for<0, nDim, 1>{}([](auto IDim) {
            constexpr auto I = decltype(IDim){};
            constexpr index_t src_len = src_lengths.Get(I);
            constexpr index_t src_sub_len = src_sub_lengths.Get(I);
            constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
            static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
                          "wrong! cannot evenly divide Src tensor lengths");
        });

        // sanity check: src read
        static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
                      "wrong! only support SrcDataPerRead == 1, 2 or 4!\n");

        static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");

        static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
                      "wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");

        static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
                      "wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
                      "keep alignment");

        // sanity check: dst write
        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
                      "wrong! only support DstDataPerWrite == 1, 2 or 4!\n");

        static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
                      "wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");

        static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
                      "wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");

        static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
                      "wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
                      "keep alignment");

        // start dividing work
        if(BlockSize > num_active_thread)
        {
            if(get_thread_local_1d_id() >= num_active_thread)
            {
                return;
            }
        }

        const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());

        // compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
        // registers, or only one copy???
        auto src_data_multi_id =
            reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);

        static_for<0, nDim, 1>{}([&](auto IDim) {
            constexpr auto I = decltype(IDim){};
            constexpr index_t i = I.Get();
            // compiler: will it really compute index here, or be associated with Get1dIndex and
            // optimized away???
            src_data_multi_id[i] *= src_sub_lengths.Get(I);
        });

        // compiler: will it really compute index here, or be associated with Get1dIndex and
        // optimized away???
        const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);

        mSrcMyThreadOffset = src_desc.Get1dIndex(src_data_multi_id);
        mDstMyThreadOffset = dst_desc.Get1dIndex(dst_data_multi_id);
    }

    __device__ static constexpr index_t GetRegisterClipboardSize()
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});

        constexpr auto repeat_lengths =
            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
                                SrcLengths{},
                                src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = transform_sequences(
            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);

        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);

        return thread_tensor_desc.GetElementSpace();
    }

    __device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
                                             Float* __restrict__ p_clipboard) const
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});

        constexpr auto repeat_lengths =
            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
                                SrcLengths{},
                                src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = transform_sequences(
            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);

        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto src_data_multi_id = transform_sequences(
                mod_conv::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);

            constexpr auto clipboard_data_multi_id = transform_sequences(
                mod_conv::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);

            constexpr index_t src_offset = SrcDesc{}.Get1dIndex(src_data_multi_id);
            constexpr index_t clipboard_offset =
                thread_tensor_desc.Get1dIndex(clipboard_data_multi_id);

            threadwise_nd_tensor_copy(SrcDesc{},
                                      p_src + src_offset + mSrcMyThreadOffset,
                                      thread_tensor_desc,
                                      p_clipboard + clipboard_offset,
                                      thread_sub_tensor_lengths,
                                      Number<SrcDataPerRead>{});
        });
    }

    __device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                              Float* __restrict__ p_dst) const
    {
        constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

        constexpr auto src_data_per_cluster_per_dims = transform_sequences(
            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});

        constexpr auto repeat_lengths =
            transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
                                SrcLengths{},
                                src_data_per_cluster_per_dims);

        constexpr auto thread_tensor_lengths = transform_sequences(
            mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, repeat_lengths);

        constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);

        static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
            constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};

            constexpr auto clipboard_data_multi_id = transform_sequences(
                mod_conv::multiplies<index_t>{}, repeat_multi_id, thread_sub_tensor_lengths);

            constexpr auto src_data_multi_id = transform_sequences(
                mod_conv::multiplies<index_t>{}, repeat_multi_id, src_data_per_cluster_per_dims);

            // reorder src_data_multi_id to get dst_data_multi_id
            constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});

            constexpr index_t clipboard_offset =
                thread_tensor_desc.Get1dIndex(clipboard_data_multi_id);

            constexpr index_t dst_offset = DstDesc{}.Get1dIndex(dst_data_multi_id);

            // write in the order of dst
#if 1
            threadwise_nd_tensor_copy_reorder_given_dst2src_v2(thread_tensor_desc,
                                                               p_clipboard + clipboard_offset,
                                                               DstDesc{},
                                                               p_dst + dst_offset +
                                                                   mDstMyThreadOffset,
                                                               thread_sub_tensor_lengths,
                                                               MapDst2Src{});
#else
            threadwise_nd_tensor_copy_reorder_given_dst2src_v3(thread_tensor_desc,
                                                               p_clipboard + clipboard_offset,
                                                               DstDesc{},
                                                               p_dst + dst_offset +
                                                                   mDstMyThreadOffset,
                                                               thread_sub_tensor_lengths,
                                                               MapDst2Src{},
                                                               Number<DstDataPerWrite>{});
#endif
        });
    }

    __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
    {
        Float p_clipboard[GetRegisterClipboardSize()];

        RunLoadRegisterClipboard(p_src, p_clipboard);
        RunStoreRegisterClipboard(p_clipboard, p_dst);
    }
};

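The new file's main departure from the deleted 4-D version is replacing the four hand-written icluster_d loops with static_ford, so the repeat loop nest is unrolled at compile time and every src/clipboard offset folds into a constant. A minimal plain-C++17 sketch (not repo code) of the 1-D building block such a construct is assembled from:

// Standalone sketch of a compile-time counted loop: f is invoked with
// integral_constant<int, 0> ... integral_constant<int, N-1>, so the index
// is usable in constexpr offset arithmetic inside the body.
#include <cstdio>
#include <utility>

template <class F, int... Is>
void static_for_impl(F f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...); // C++17 fold
}

template <int N, class F>
void static_for(F f)
{
    static_for_impl(f, std::make_integer_sequence<int, N>{});
}

int main()
{
    static_for<4>([](auto I) {
        constexpr int i = decltype(I)::value; // compile-time constant
        std::printf("unrolled iteration %d\n", i);
    });
}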
@@ -73,7 +73,6 @@ __host__ __device__ constexpr auto get_convolution_with_padding_output_default_4

template <class InDesc, class WeiDesc, class OutDesc>
__host__ __device__ constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
{
    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

@@ -1,5 +1,6 @@
 #pragma once
 #include "config.h"
+#include "constant_integral.hip.hpp"

 template <class T, index_t N>
 struct vector_type
@@ -10,6 +11,13 @@ template <>
 struct vector_type<float, 1>
 {
     typedef float MemoryType;
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+    {
+        static_assert(I < 1, "wrong");
+        *(reinterpret_cast<float*>(&v) + I) = s;
+    }
 };

 template <>
@@ -20,21 +28,29 @@ struct vector_type<float, 2>
     // instruction
     typedef float MemoryType __attribute__((ext_vector_type(2)));
 #elif DEVICE_BACKEND_CUDA
-    // For some reason, CUDA need this definition to, otherwise
+    // For some reason, CUDA need this definition, otherwise
     // compiler won't generate optimal load and store instruction, and
     // kernel would produce wrong result, indicating the compiler fail to generate correct
     // instruction,
     using MemoryType = float2;
 #endif

+    union Data
+    {
+        MemoryType vector;
+        float scalar[2];
+    };
+
+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+    {
+        static_assert(I < 2, "wrong");
+        *(reinterpret_cast<float*>(&v) + I) = s;
+    }
+
     __host__ __device__ static MemoryType Pack(float s0, float s1)
     {
-        union
-        {
-            MemoryType vector;
-            float scalar[2];
-        } data;
-
+        Data data;
         data.scalar[0] = s0;
         data.scalar[1] = s1;
         return data.vector;
@@ -49,12 +65,19 @@ struct vector_type<float, 4>
     // instruction
     typedef float MemoryType __attribute__((ext_vector_type(4)));
 #elif DEVICE_BACKEND_CUDA
-    // For some reason, CUDA need this definition to, otherwise
+    // For some reason, CUDA need this definition, otherwise
     // compiler won't generate optimal load and store instruction, and
     // kernel would produce wrong result, indicating the compiler fail to generate correct
     // instruction,
     using MemoryType = float4;
 #endif

+    template <index_t I>
+    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
+    {
+        static_assert(I < 4, "wrong");
+        *(reinterpret_cast<float*>(&v) + I) = s;
+    }
 };

 #if 0

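The Data union introduced here is the usual GPU idiom for building a vector value from scalars: write the scalar view, then read the vector view back. A standalone sketch (not repo code; a plain struct stands in for the backend float2, and note that type-punning through a union is strictly UB in ISO C++ but is the accepted practice in HIP/CUDA code like this header):

// Standalone sketch of the Pack idiom the hunk refactors.
#include <cstdio>

struct float2_t { float x, y; }; // stand-in for the backend vector type

static float2_t pack(float s0, float s1)
{
    union
    {
        float2_t vector;
        float scalar[2];
    } data;

    data.scalar[0] = s0; // fill the scalar view ...
    data.scalar[1] = s1;
    return data.vector;  // ... read back the vector view
}

int main()
{
    const float2_t v = pack(1.0f, 2.0f);
    std::printf("%f %f\n", v.x, v.y);
}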
@@ -4,7 +4,7 @@
 #include "ConstantMatrixDescriptor.hip.hpp"
 #include "blockwise_2d_tensor_op.hip.hpp"
 #include "blockwise_3d_tensor_op.hip.hpp"
-#include "blockwise_4d_tensor_op.hip.hpp"
+#include "blockwise_nd_tensor_op.hip.hpp"
 #include "threadwise_nd_tensor_op.hip.hpp"
 #include "threadwise_4d_tensor_op.hip.hpp"
 #include "blockwise_batched_gemm.hip.hpp"
@@ -125,17 +125,17 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
         constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

         const auto blockwise_in_copy_reorder =
-            Blockwise4dTensorCopyReorder3<BlockSize,
-                                          Float,
-                                          decltype(in_n_c_h_w_global_desc),
-                                          decltype(in_c_h_w_n_block_desc),
-                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
-                                          InBlockReorderSrcSubLengths_NCHW,
-                                          InBlockReorderSrcClusterLengths_NCHW,
-                                          decltype(map_chwn2nchw),
-                                          InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
-                                          InBlockReorderDataPerRead_W,
-                                          InBlockReorderDataPerWrite_N>{};
+            BlockwiseNdTensorCopyReorder_v3<BlockSize,
+                                            Float,
+                                            decltype(in_n_c_h_w_global_desc),
+                                            decltype(in_c_h_w_n_block_desc),
+                                            Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
+                                            InBlockReorderSrcSubLengths_NCHW,
+                                            InBlockReorderSrcClusterLengths_NCHW,
+                                            decltype(map_chwn2nchw),
+                                            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
+                                            InBlockReorderDataPerRead_W,
+                                            InBlockReorderDataPerWrite_N>{};

         // blockwise wei copy
         //   format is [CPerBlock, X * KPerBlock]

@@ -3,7 +3,7 @@
 #include "ConstantTensorDescriptor.hip.hpp"
 #include "ConstantMatrixDescriptor.hip.hpp"
 #include "blockwise_2d_tensor_op.hip.hpp"
-#include "blockwise_4d_tensor_op.hip.hpp"
+#include "blockwise_nd_tensor_op.hip.hpp"
 #include "threadwise_nd_tensor_op.hip.hpp"
 #include "threadwise_4d_tensor_op.hip.hpp"
 #include "blockwise_batched_gemm.hip.hpp"
@@ -133,17 +133,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
         constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

         const auto blockwise_in_copy_reorder =
-            Blockwise4dTensorCopyReorder3<BlockSize,
-                                          Float,
-                                          decltype(in_n_c_h_w_global_desc),
-                                          decltype(in_c_h_w_n_block_desc),
-                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
-                                          InBlockReorderSrcSubLengths_NCHW,
-                                          InBlockReorderSrcClusterLengths_NCHW,
-                                          decltype(map_chwn2nchw),
-                                          InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
-                                          InBlockReorderDataPerRead_W,
-                                          InBlockReorderDataPerWrite_N>{};
+            BlockwiseNdTensorCopyReorder_v3<BlockSize,
+                                            Float,
+                                            decltype(in_n_c_h_w_global_desc),
+                                            decltype(in_c_h_w_n_block_desc),
+                                            Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
+                                            InBlockReorderSrcSubLengths_NCHW,
+                                            InBlockReorderSrcClusterLengths_NCHW,
+                                            decltype(map_chwn2nchw),
+                                            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
+                                            InBlockReorderDataPerRead_W,
+                                            InBlockReorderDataPerWrite_N>{};

         // blockwise wei copy
         //   format is [CPerBlock, KPerBlock]

@@ -3,7 +3,7 @@
 #include "ConstantTensorDescriptor.hip.hpp"
 #include "ConstantMatrixDescriptor.hip.hpp"
 #include "blockwise_2d_tensor_op.hip.hpp"
-#include "blockwise_4d_tensor_op.hip.hpp"
+#include "blockwise_nd_tensor_op.hip.hpp"
 #include "threadwise_nd_tensor_op.hip.hpp"
 #include "threadwise_4d_tensor_op.hip.hpp"
 #include "blockwise_batched_gemm.hip.hpp"
@@ -130,17 +130,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
         constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

         const auto blockwise_in_copy_reorder =
-            Blockwise4dTensorCopyReorder3<BlockSize,
-                                          Float,
-                                          decltype(in_n_c_h_w_global_desc),
-                                          decltype(in_c_h_w_n_block_desc),
-                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
-                                          InBlockReorderSrcSubLengths_NCHW,
-                                          InBlockReorderSrcClusterLengths_NCHW,
-                                          decltype(map_chwn2nchw),
-                                          InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
-                                          InBlockReorderDataPerRead_W,
-                                          InBlockReorderDataPerWrite_N>{};
+            BlockwiseNdTensorCopyReorder_v3<BlockSize,
+                                            Float,
+                                            decltype(in_n_c_h_w_global_desc),
+                                            decltype(in_c_h_w_n_block_desc),
+                                            Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
+                                            InBlockReorderSrcSubLengths_NCHW,
+                                            InBlockReorderSrcClusterLengths_NCHW,
+                                            decltype(map_chwn2nchw),
+                                            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
+                                            InBlockReorderDataPerRead_W,
+                                            InBlockReorderDataPerWrite_N>{};

         // blockwise wei copy
         //   format is [CPerBlock, KPerBlock]

@@ -0,0 +1,514 @@
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_nd_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockReorderSrcSubLengths_NCHW,
          class InBlockReorderSrcClusterLengths_NCHW,
          class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
          index_t InBlockReorderDataPerRead_W,
          index_t InBlockReorderDataPerWrite_N,
          class WeiBlockCopyClusterLengths_CK, // not used
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_W>
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
                                                      NPerBlock % GemmNPerThreadSubC == 0) ||
                          (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                           GemmNPerThreadSubC % NPerThread == 0),
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t N  = out_n_k_h_w_global_desc.GetLength(I0);
        constexpr index_t K  = out_n_k_h_w_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup ");

        // constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});

        // LDS tensor view
        //   be careful of alignment
        constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
                                                    WeiBlockCopyDataPerRead_K,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
            Number<InBlockReorderDataPerWrite_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        //   input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
        constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

        const auto blockwise_in_copy_reorder =
            BlockwiseNdTensorCopyReorder_v3<BlockSize,
                                            Float,
                                            decltype(in_n_c_h_w_global_desc),
                                            decltype(in_c_h_w_n_block_desc),
                                            Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
                                            InBlockReorderSrcSubLengths_NCHW,
                                            InBlockReorderSrcClusterLengths_NCHW,
                                            decltype(map_chwn2nchw),
                                            InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                            InBlockReorderDataPerRead_W,
                                            InBlockReorderDataPerWrite_N>{};

        // blockwise wei copy
        //   format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix saved in LDS, C_matrix saved in register
        //   A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        //   C++ lambda doesn't capture array, use pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");

            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 1
        const Float* p_in_global_block_offset =
            p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
                              n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
#if 1
                    blockwise_in_copy_reorder.Run(p_in_global_block_offset +
                                                      in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x),
                                                  p_in_block);

                    blockwise_wei_copy.Run(p_wei_global_block_offset +
                                               wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0),
                                           p_wei_block);
#else
                    Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                    Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(
                        p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x),
                        p_in_clipboard);

                    blockwise_wei_copy.RunLoadRegisterClipboard(
                        p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0),
                        p_wei_clipboard);

                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);

                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);

#endif

                    __syncthreads();

                    blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#else
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global +
                    in_n_c_h_w_global_desc.Get1dIndex(
                        n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);

                const Float* p_wei_global_block_offset =
                    p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);

                for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                    c_block_data_begin += CPerBlock,
                            p_in_global_block_offset +=
                            CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
                            p_wei_global_block_offset +=
                            CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
                {
#if 0
                    blockwise_in_copy_reorder.Run(p_in_global_block_offset,
                                                  p_in_block);

                    blockwise_wei_copy.Run(p_wei_global_block_offset,
                                           p_wei_block);
#else
                    Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                    Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_clipboard);

                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
#endif

                    __syncthreads();

                    blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#endif

        // output: register to global mem,
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}(
            [&](auto f_dummy) { // f_dummy does nothing but perfect forwarding. Using this trick to
                                // make this lambda a generic lambda, so it won't be compiled until
                                // instantiated
                static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
                               NPerBlock % GemmNPerThreadSubC == 0),
                              "wrong!");

                // output is a 10d tensor
                constexpr index_t N2 = GemmNPerThreadSubC;
                constexpr index_t N1 = NPerBlock / N2;

                constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
                                       f_dummy(NPerBlock / GemmNPerThreadSubC);
                constexpr index_t W1 = WoPerBlock / W2;

                constexpr index_t K2 = GemmMPerThreadSubC;
                constexpr index_t K1 = KPerBlock / KPerThread;

#if 0
                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2),
                                                           W1,
                                                           W2,
                                                           N / f_dummy(N1 * N2),
                                                           N1,
                                                           N2>{});
#else
                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<N / f_dummy(N1 * N2),
                                                           N1,
                                                           N2,
                                                           K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2),
                                                           W1,
                                                           W2>{});
#endif

                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                    print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                                   "out_k_h_w_n_thread_desc");
                    print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                    print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                                   "out_k_h_w_n_global_desc");
                    print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
                }
#endif

#if 0
                threadwise_nd_tensor_copy(out_10d_thread_desc,
                                          p_out_thread,
                                          out_10d_global_desc,
                                          p_out_global +
                                              out_k_h_w_n_global_desc.Get1dIndex(
                                                  k_block_data_begin + k_thread_data_begin,
                                                  ho_block_data_begin + ho_thread_data_begin,
                                                  wo_block_data_begin + wo_thread_data_begin,
                                                  n_block_data_begin + n_thread_data_begin),
                                          out_10d_thread_desc.GetLengths(),
                                          Number<OutThreadCopyDataPerWrite_N>{});
#else
                constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 6, 3, 4, 5>{};

                threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
                    out_10d_thread_desc,
                    p_out_thread,
                    out_10d_global_desc,
                    p_out_global + out_n_k_h_w_global_desc.Get1dIndex(
                                       n_block_data_begin + n_thread_data_begin,
                                       k_block_data_begin + k_thread_data_begin,
                                       ho_block_data_begin + ho_thread_data_begin,
                                       wo_block_data_begin + wo_thread_data_begin),
                    out_10d_thread_desc.GetLengths(),
                    map_out_global2thread);
                // Number<OutThreadCopyDataPerWrite_W>{});
#endif
            })
            .else_([&](auto f_dummy) {
                static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                                  GemmNPerThreadSubC % NPerThread == 0,
                              "wrong!");

                // output is a 10d tensor
                constexpr index_t N1 = NPerBlock;

                constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
                constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
                constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

                constexpr index_t K2 = GemmMPerThreadSubC;
                constexpr index_t K1 = KPerBlock / KPerThread;

#if 0
                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2 * W3),
                                                           W1,
                                                           W2,
                                                           W3,
                                                           N / N1,
                                                           N1>{});
#else
                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<N / N1,
                                                           N1,
                                                           K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2 * W3),
                                                           W1,
                                                           W2,
                                                           W3>{});
#endif

                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                    print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                                   "out_k_h_w_n_thread_desc");
                    print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                    print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                                   "out_k_h_w_n_global_desc");
                    print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                    for(index_t i = 0; i < 64; ++i)
                    {
                        printf("out %f, ", p_out_thread[i]);
                    }
                }
#endif

#if 0
                threadwise_nd_tensor_copy(out_10d_thread_desc,
                                          p_out_thread,
                                          out_10d_global_desc,
                                          p_out_global +
                                              out_k_h_w_n_global_desc.Get1dIndex(
                                                  k_block_data_begin + k_thread_data_begin,
                                                  ho_block_data_begin + ho_thread_data_begin,
                                                  wo_block_data_begin + wo_thread_data_begin,
                                                  n_block_data_begin + n_thread_data_begin),
                                          out_10d_thread_desc.GetLengths(),
                                          Number<OutThreadCopyDataPerWrite_N>{});
#else
                constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};

                threadwise_nd_tensor_copy_reorder_given_dst2src_v2(
                    out_10d_thread_desc,
                    p_out_thread,
                    out_10d_global_desc,
                    p_out_global + out_n_k_h_w_global_desc.Get1dIndex(
                                       n_block_data_begin + n_thread_data_begin,
                                       k_block_data_begin + k_thread_data_begin,
                                       ho_block_data_begin + ho_thread_data_begin,
                                       wo_block_data_begin + wo_thread_data_begin),
                    out_10d_thread_desc.GetLengths(),
                    map_out_global2thread);
                // Number<OutThreadCopyDataPerWrite_W>{});
#endif
            });
    }
};

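The block-work division near the top of Run() peels a flat workgroup id into (k, h, w, n) coordinates by repeated division, with N varying fastest. A worked example with hypothetical work counts (not repo code):

// Standalone sketch of the k/h/w/n decomposition used in Run().
#include <cassert>

int main()
{
    constexpr int HBlockWork = 4, WBlockWork = 8, NBlockWork = 2;

    const int block_id = 37; // ((k*4 + h)*8 + w)*2 + n with k=0, h=2, w=2, n=1

    const int k = block_id / (HBlockWork * WBlockWork * NBlockWork);
    int itmp    = block_id - k * (HBlockWork * WBlockWork * NBlockWork);
    const int h = itmp / (WBlockWork * NBlockWork);
    itmp -= h * (WBlockWork * NBlockWork);
    const int w = itmp / NBlockWork;
    const int n = itmp - w * NBlockWork;

    assert(k == 0 && h == 2 && w == 2 && n == 1);
    (void)k; (void)h; (void)w; (void)n; // silence unused warnings under NDEBUG
}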
@@ -139,135 +139,6 @@ __device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc,
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
#if 0 // replaced threadwise_nd_tensor_copy
|
||||
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void threadwise_4d_tensor_copy(
|
||||
SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
|
||||
{
|
||||
auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder_given_dst2src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
|
||||
// need to assume src and dst is aligned
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
|
||||
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
|
||||
SrcOpLengths::GetSize() == 4,
|
||||
"wrong! should be 4 dimension");
|
||||
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
|
||||
"wrong! only support stride3 == 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I2) % DataPerRead == 0,
|
||||
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L3 = SrcOpLengths{}.Get(I3);
|
||||
|
||||
static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");
|
||||
|
||||
constexpr index_t nloop_d3 = L3 / DataPerRead;
|
||||
|
||||
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t src_index =
|
||||
src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
|
||||
|
||||
const index_t dst_index =
|
||||
dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[src_index]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
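
// Hedged sketch, not from this commit: the disabled copy above relies on
// reinterpreting DataPerRead contiguous scalars as a single vector load/store. A
// minimal 1d host-side model with Float = float and DataPerRead = 4; float_vec4_copy
// is a stand-in for vector_type<float, 4>::MemoryType, and n and both pointers are
// assumed to satisfy the divisibility/alignment constraints the static_asserts above
// enforce:
struct float_vec4_copy
{
    float x, y, z, w;
};

inline void copy_vectorized_1d_sketch(const float* p_src, float* p_dst, int n)
{
    // one 16-byte transaction moves four floats per iteration
    for(int i = 0; i < n; i += 4)
    {
        *(reinterpret_cast<float_vec4_copy*>(&p_dst[i])) =
            *(reinterpret_cast<const float_vec4_copy*>(&p_src[i]));
    }
}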

template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
threadwise_4d_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
                                                   const SrcData* __restrict__ p_src,
                                                   DstDesc,
                                                   DstData* __restrict__ p_dst,
                                                   SrcOpLengths,
                                                   MapDst2Src)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr index_t IR0 = MapDst2Src{}.Get(I0);
    constexpr index_t IR1 = MapDst2Src{}.Get(I1);
    constexpr index_t IR2 = MapDst2Src{}.Get(I2);
    constexpr index_t IR3 = MapDst2Src{}.Get(I3);

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    // ref_desc has dst_desc's ordering
    constexpr auto ref_desc =
        make_ConstantTensorDescriptor(SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{}));

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    const auto dst_multi_id = Array<index_t, 4>{did0, did1, did2, did3};

                    const auto src_multi_id =
                        reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

                    const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);

                    const index_t src_index = src_desc.Get1dIndex(src_multi_id);

                    p_dst[dst_index] = p_src[src_index];
                }
            }
        }
    }
}
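
// Hedged sketch, not from this commit: with MapDst2Src = Sequence<0, 2, 3, 1>, dst
// dimension i takes its length and coordinate from src dimension MapDst2Src[i], so a
// src tensor of lengths {N, C, H, W} is written out with lengths {N, H, W, C}. A
// minimal host-side model of the index relation used above, assuming
// reorder_array_given_old2new computes result[map[i]] = input[i]:
inline void src_id_from_dst_id_sketch(const int dst_id[4], int src_id[4])
{
    const int map_dst2src[4] = {0, 2, 3, 1}; // hypothetical transpose map

    for(int i = 0; i < 4; ++i)
    {
        // src dimension map_dst2src[i] carries dst coordinate i
        src_id[map_dst2src[i]] = dst_id[i];
    }
}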

template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{

@@ -50,7 +50,7 @@ __device__ void threadwise_nd_tensor_copy(SrcDesc,
    constexpr index_t nRead = L_Back / DataPerRead;

    static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
        static_for<0, nRead, 1>{}([=](auto IRead) {
        static_for<0, nRead, 1>{}([&](auto IRead) {
            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});

            const index_t src_index = src_desc.Get1dIndex(multi_id);
@@ -62,3 +62,131 @@ __device__ void threadwise_nd_tensor_copy(SrcDesc,
        });
    });
}

// write in order of src
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
threadwise_nd_tensor_copy_reorder_given_dst2src_v1(SrcDesc,
                                                   const SrcData* __restrict__ p_src,
                                                   DstDesc,
                                                   DstData* __restrict__ p_dst,
                                                   SrcOpLengths,
                                                   MapDst2Src)
{
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    ford<SrcOpLengths>{}([&](auto src_multi_id) {
        const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{});

        const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);

        const index_t src_index = src_desc.Get1dIndex(src_multi_id);

        p_dst[dst_index] = p_src[src_index];
    });
}

// write in order of dst
template <class SrcData,
          class DstData,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src>
__device__ void
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
                                                   const SrcData* __restrict__ p_src,
                                                   DstDesc,
                                                   DstData* __restrict__ p_dst,
                                                   SrcOpLengths,
                                                   MapDst2Src)
{
    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});

    ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) {
        const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

        const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);

        const index_t src_index = src_desc.Get1dIndex(src_multi_id);

        p_dst[dst_index] = p_src[src_index];
    });
}
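
// Hedged note, not from this commit: v1 and v2 realize the same element
// correspondence, src_multi_id[MapDst2Src[i]] == dst_multi_id[i]; they differ only in
// traversal order, which is what the "write in order of src/dst" comments refer to.
// v2 is the variant the gridwise kernel above uses for its output copy. A
// hypothetical call that writes a {2, 3, 4, 5} src tensor out with its last three
// dimensions rotated (descriptors and pointers assumed in scope):
//
//   threadwise_nd_tensor_copy_reorder_given_dst2src_v2(src_desc,
//                                                      p_src,
//                                                      dst_desc,
//                                                      p_dst,
//                                                      Sequence<2, 3, 4, 5>{},
//                                                      Sequence<0, 2, 3, 1>{});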

// write in order of dst
template <class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class MapDst2Src,
          index_t DstDataPerWrite>
__device__ void threadwise_nd_tensor_copy_reorder_given_dst2src_v3(SrcDesc,
                                                                   const Float* __restrict__ p_src,
                                                                   DstDesc,
                                                                   Float* __restrict__ p_dst,
                                                                   SrcOpLengths,
                                                                   MapDst2Src,
                                                                   Number<DstDataPerWrite>)
{
    using vector_t = typename vector_type<Float, DstDataPerWrite>::MemoryType;

    constexpr index_t nDim = SrcOpLengths::GetSize();

    static_assert(DstDataPerWrite == 1 || DstDesc{}.GetStride(Number<nDim - 1>{}) == 1,
                  "wrong! only support dst.stride[nDim-1] == 1, if DstDataPerWrite != 1");

    static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
                  "wrong! only support DstDataPerWrite == 1, 2 or 4");

    static_assert(
        DstDesc{}.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
        "wrong! dst.stride[nDim-2] should be multiple of DstDataPerWrite to keep alignment");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});

    constexpr index_t L_Dst_Back = dst_op_lengths.Back();

    static_assert(L_Dst_Back % DstDataPerWrite == 0,
                  "wrong! dst.lengths[nDim-1] should be evenly divided by DstDataPerWrite");

    constexpr index_t nWrite = L_Dst_Back / DstDataPerWrite;

    ford<decltype(dst_op_lengths.PopBack())>{}([&](auto ids) {
        static_for<0, nWrite, 1>{}([&](auto IWrite) {
            vector_t dst_vec_data;

            // pack data
            static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
                const auto dst_multi_id =
                    ids.PushBack(IWrite.Get() * DstDataPerWrite + IDstData.Get());

                const auto src_multi_id =
                    reorder_array_given_old2new(dst_multi_id, MapDst2Src{});

                const index_t src_index = src_desc.Get1dIndex(src_multi_id);

                vector_type<Float, DstDataPerWrite>::SetScalar(
                    dst_vec_data, p_src[src_index], IDstData);
            });

            // write data
            const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite);

            const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);

            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data;
        });
    });
}
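
// Hedged sketch, not from this commit: v3 gathers DstDataPerWrite scalars from
// (possibly strided) src locations into one register-resident vector, then issues a
// single vector store to the contiguous dst tail dimension. A minimal host-side model
// with Float = float and DstDataPerWrite = 4; float_vec4 is a stand-in for
// vector_type<float, 4>::MemoryType:
struct float_vec4
{
    float s[4];
};

inline void pack_and_write4_sketch(const float* p_src,
                                   const int src_offset[4], // gathered source offsets
                                   float* p_dst,
                                   int dst_offset) // contiguous, suitably aligned
{
    float_vec4 vec;

    // pack data, like vector_type<>::SetScalar in the loop above
    for(int i = 0; i < 4; ++i)
    {
        vec.s[i] = p_src[src_offset[i]];
    }

    // single vector write, like the reinterpret_cast store above
    *(reinterpret_cast<float_vec4*>(&p_dst[dst_offset])) = vec;
}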