reorginzed files

This commit is contained in:
Chao Liu
2019-06-13 15:12:12 -05:00
parent c82b833d8e
commit 1566b31736
64 changed files with 254 additions and 218 deletions

View File

@@ -0,0 +1,806 @@
#ifndef CK_BLOCKWISE_2D_TENSOR_OP_HPP
#define CK_BLOCKWISE_2D_TENSOR_OP_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace ck {
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto dst_desc = DstDesc{};
constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
#if 0
if(get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
}
#endif
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const index_t did1 = is / desc.GetStride(I1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
f(p_dst[dindex]);
}
constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < desc.GetElementSize())
{
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const index_t did1 = is / desc.GetStride(I1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
f(p_dst[dindex]);
}
}
}
// Function: p_dst[reorder[i0], reorder[i1] = p_src[i0,i1]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src,
class F>
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
SrcOpLengths,
MapDst2Src,
F f)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
index_t did[2];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
index_t did[2];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
f(p_src[aindex], p_dst[bindex]);
}
}
}
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
SrcOpLengths,
MapDst2Src)
{
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
index_t DataPerRead>
struct Blockwise2dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
__device__ constexpr Blockwise2dTensorCopy1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
"wrong! only support stride1 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
DstDesc{}.GetStride(I0) % DataPerRead == 0,
"src and dst stride2 should be multiple of DataPerRead to keep alignment");
// we allow out-of-bound read from src in D1 dimension,
// but we need to make sure dst stride0 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);
static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
"wrong! out-of-bound write will contaminate next line!\n");
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](index_t is) {
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
const index_t src_index =
src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
f_copy(is);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
f_copy(is);
}
}
}
};
// need to be aligned to float4 and float2
// stride1 need to be 1 for both source and destination
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
index_t ThreadPerDim0,
index_t ThreadPerDim1>
struct Blockwise2dTensorCopy2
{
index_t mThreadId0;
index_t mThreadId1;
__device__ Blockwise2dTensorCopy2()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
"wrong! stride is not 1!\n");
mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
using Float4 = float4;
using Float2 = float2;
if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
return;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
// check alignment
constexpr bool align_v4 =
src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;
constexpr bool align_v2 =
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
constexpr index_t L0 = SrcOpLengths{}.Get(I0);
constexpr index_t L1 = SrcOpLengths{}.Get(I1);
constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
constexpr index_t Dim1V2Loop =
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
constexpr index_t Dim1V1Loop =
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
ThreadPerDim1;
constexpr bool d1_has_tail =
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
{
index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
// v4
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}
// v2
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
index_t did1 =
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}
// v1
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
// dim-1 tail
if(d1_has_tail)
{
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
if(did1 < L1)
{
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
}
}
// dim-0 tail
if(d0_has_tail)
{
index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
if(did0 < L0)
{
// v4
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
{
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
*(reinterpret_cast<const Float4*>(p_src + sindex));
}
// v2
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
{
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
2 * mThreadId1;
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
*(reinterpret_cast<const Float2*>(p_src + sindex));
}
// v1
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
{
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
d1v1loop * ThreadPerDim1 + mThreadId1;
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
// tail
if(d1_has_tail)
{
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
if(did1 < L1)
{
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
p_dst[dindex] = p_src[sindex];
}
}
}
}
}
};
// starting point need to be aligned to float4 or float2 or float
// stride1 need to be 1 for both source and destination
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
index_t DataPerRead>
struct Blockwise2dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise2dTensorCopy3(Array<index_t, 2> src_block_data_multi_id_begin,
Array<index_t, 2> dst_block_data_multi_id_begin)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
"wrong! only support stride1 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
DstDesc{}.GetStride(I0) % DataPerRead == 0,
"src and dst stride should be multiple of DataPerRead to keep alignment");
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
// we allow out-of-bound read from src in D1 dimension,
// but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate next line in dst
static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
"wrong! out-of-bound write will contaminate next line!\n");
static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
src_block_data_multi_id_begin +
Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
dst_block_data_multi_id_begin +
Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
iloop * src_loop_stride));
};
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
if(has_tail_d0)
{
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
f_copy(nloop_d0);
}
}
}
__device__ constexpr index_t GetRegisterClipboardSize() const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
}
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
*(reinterpret_cast<const vector_t*>(
&p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
};
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
if(has_tail_d0)
{
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
f_copy(nloop_d0);
}
}
}
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](index_t iloop) {
*(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
*(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
};
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
if(has_tail_d0)
{
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
f_copy(nloop_d0);
}
}
}
#if CK_USE_AMD_INLINE_ASM
__device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
Float* p_clipboard) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](index_t iloop) {
#if 0
*(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
*(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
iloop * src_loop_stride]));
#else
static_assert(is_same<float, Float>::value && DataPerRead == 4,
"global_load is only for float4");
global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
reinterpret_cast<const vector_t*>(
&p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
#endif
};
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
if(has_tail_d0)
{
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
f_copy(nloop_d0);
}
}
}
__device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
auto f_copy = [&](index_t iloop) {
#if 0
*(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]) =
*(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]);
#else
static_assert(is_same<float, Float>::value && DataPerRead == 4,
"ds_write_b128 is only for float4");
ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
#endif
};
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
{
f_copy(iloop);
}
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
if(has_tail_d0)
{
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
{
f_copy(nloop_d0);
}
}
}
#endif
};
} // namespace ck
#endif

View File

@@ -0,0 +1,378 @@
#ifndef CK_BLOCKWISE_3D_TENSOR_OP_HPP
#define CK_BLOCKWISE_3D_TENSOR_OP_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace ck {
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
index_t DataPerRead>
struct Blockwise3dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
__device__ constexpr Blockwise3dTensorCopy1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
"wrong! only support stride2 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
DstDesc{}.GetStride(I1) % DataPerRead == 0,
"src and dst stride1 should be multiple of DataPerRead to keep alignment");
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
"wrong! out-of-bound write will contaminate next line!\n");
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](index_t is) {
index_t did[3];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
const index_t src_index =
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
f_copy(is);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
f_copy(is);
}
}
}
};
// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
class ThreadPerDims,
index_t DataPerRead>
struct Blockwise3dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise3dTensorCopy3()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(
SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
DstDesc{}.GetStride(I1) % DataPerRead == 0,
"wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
// we allow out-of-bound read from src in D2 dimension,
// but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
"wrong! out-of-bound write will contaminate next line!\n");
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
"wrong! L0, L1, L2 should be divided evenly!\n");
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
"wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr index_t num_active_thread =
accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
const auto thread_multi_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
const index_t src_offset =
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
const index_t dst_offset =
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) = *(
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
}
}
}
}
__device__ static constexpr index_t GetRegisterClipboardSize()
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
}
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
constexpr auto clipboard_desc =
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
const index_t src_offset =
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) = *(
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
}
}
}
}
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
constexpr auto clipboard_desc =
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
const index_t dst_offset =
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
}
}
}
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,779 @@
#ifndef CK_BLOCKWISE_4D_TENSOR_OP_HPP
#define CK_BLOCKWISE_4D_TENSOR_OP_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"
namespace ck {
template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto dst_desc = DstDesc{};
constexpr auto desc = make_ConstantTensorDescriptor_packed(dst_desc.GetLengths());
#if 0
if(get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
}
#endif
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1);
const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2);
const index_t did3 = is / desc.GetStride(I3);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
f(p_dst[dindex]);
}
constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < desc.GetElementSize())
{
const index_t did0 = is / desc.GetStride(I0);
is -= did0 * desc.GetStride(I0);
const index_t did1 = is / desc.GetStride(I1);
is -= did1 * desc.GetStride(I1);
const index_t did2 = is / desc.GetStride(I2);
is -= did2 * desc.GetStride(I2);
const index_t did3 = is / desc.GetStride(I3);
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
f(p_dst[dindex]);
}
}
}
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src,
class F>
__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
SrcOpLengths,
MapDst2Src,
F f)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
constexpr index_t IR2 = MapDst2Src{}.Get(I2);
constexpr index_t IR3 = MapDst2Src{}.Get(I3);
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
is -= did[2] * ref_desc.GetStride(I2);
did[3] = is / ref_desc.GetStride(I3);
const index_t src_index = src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
is -= did[2] * ref_desc.GetStride(I2);
did[3] = is / ref_desc.GetStride(I3);
const index_t src_index =
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
f(p_src[src_index], p_dst[dst_index]);
}
}
}
template <index_t BlockSize, class Float, class DstDesc>
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
auto f_set_zero = [](Float& v) { v = Float(0); };
blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
}
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src>
__device__ void
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
SrcOpLengths,
MapDst2Src)
{
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
index_t DataPerRead>
struct Blockwise4dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
__device__ constexpr Blockwise4dTensorCopy1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"src and dst stride2 should be multiple of DataPerRead to keep alignment");
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride2 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
constexpr auto ref_desc =
make_ConstantTensorDescriptor_packed(Sequence<L0, L1, L2, read_per_d3>{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](index_t is) {
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
is -= did[2] * ref_desc.GetStride(I2);
did[3] = is / ref_desc.GetStride(I3);
const index_t src_index =
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
const index_t dst_index =
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
f_copy(is);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
f_copy(is);
}
}
}
};
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class DstOpLengths,
class GlobalLowerPads>
struct BlockwiseChwnTensorCopyPadded
{
__device__ void Run(const Float* __restrict__ p_src,
index_t c_block_data_begin,
index_t ho_block_data_begin,
index_t wo_block_data_begin,
index_t n_block_data_begin,
Float* __restrict__ p_dst,
index_t h_block_pad_low,
index_t w_block_pad_low,
index_t h_block_pad_up,
index_t w_block_pad_up) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(DstOpLengths{});
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
const Float* p_src_tmp = p_src +
src_desc.GetOffsetFromMultiIndex(
c_block_data_begin,
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
n_block_data_begin);
#if 0
if(get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(src_desc, "src_desc: ");
print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");
printf("%u %u, \t"
"h_global_pad_low %u w_global_pad_low %u \t"
"h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
"\n",
get_block_1d_id(),
get_thread_local_1d_id(),
h_global_pad_low,
w_global_pad_low,
h_block_pad_low,
w_block_pad_low,
h_block_pad_up,
w_block_pad_up);
}
#endif
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
is -= did[2] * ref_desc.GetStride(I2);
did[3] = is / ref_desc.GetStride(I3);
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] =
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
? Float(0)
: p_src_tmp[src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3])];
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
is -= did[2] * ref_desc.GetStride(I2);
did[3] = is / ref_desc.GetStride(I3);
const index_t bindex =
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
p_dst[bindex] =
(did[1] < h_block_pad_low ||
did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
? Float(0)
: p_src_tmp[src_desc.GetOffsetFromMultiIndex(
did[0], did[1], did[2], did[3])];
}
}
}
};
// starting point need to be aligned to float4 or float2 or float
// stride3 need to be 1 for both source and destination
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
class ThreadPerDims,
index_t DataPerRead>
struct Blockwise4dTensorCopy3
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise4dTensorCopy3()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(
SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
DstDesc{}.GetStride(I2) % DataPerRead == 0,
"wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
// we allow out-of-bound read from src in D3 dimension,
// but we need to make sure dst stride is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
"wrong! out-of-bound write will contaminate next line!\n");
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 && L2 % thread_per_d2 == 0,
"wrong! L0, L1, L2 should be divided evenly!\n");
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
"wrrong! BlockSize is not big enough for ThreadPerDims!");
constexpr index_t num_active_thread =
accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(ThreadPerDims{});
const auto thread_multi_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
thread_multi_id[1],
thread_multi_id[2],
thread_multi_id[3] * DataPerRead);
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
thread_multi_id[1],
thread_multi_id[2],
thread_multi_id[3] * DataPerRead);
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
#pragma unroll
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
*(reinterpret_cast<const vector_t*>(
&p_src[src_offset + mSrcMyThreadOffset]));
}
}
}
}
}
__device__ constexpr index_t GetRegisterClipboardSize() const
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
}
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
#pragma unroll
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
*(reinterpret_cast<const vector_t*>(
&p_src[src_offset + mSrcMyThreadOffset]));
}
}
}
}
}
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t L3 = CopyLengths{}.Get(I3);
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
constexpr index_t num_active_thread =
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
constexpr index_t nloop_d0 = L0 / thread_per_d0;
constexpr index_t nloop_d1 = L1 / thread_per_d1;
constexpr index_t nloop_d2 = L2 / thread_per_d2;
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
#pragma unroll
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
{
#pragma unroll
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
{
#pragma unroll
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
{
#pragma unroll
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
{
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
iloop_d0 * thread_per_d0,
iloop_d1 * thread_per_d1,
iloop_d2 * thread_per_d2,
iloop_d3 * thread_per_d3 * DataPerRead);
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
}
}
}
}
}
};
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src>
struct Blockwise4dTensorCopyReorder1
{
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}
};
} // namespace
#endif

View File

@@ -0,0 +1,529 @@
#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
#define CK_BLOCKWISE_BATCHED_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
index_t BlockMatrixStrideA,
index_t BlockMatrixStrideB,
index_t ThreadMatrixStrideC,
index_t BatchSize,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t BatchPerThread,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
index_t mMyThreadOffsetA = 0;
index_t mMyThreadOffsetB = 0;
struct MatrixIndex
{
index_t batch;
index_t row;
index_t col;
};
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
{
static_assert(BatchSize % BatchPerThread == 0,
"wrong! BatchSize is not dividable by BatchPerThread");
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
"wrong! wrong blocksize\n");
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
"wrong! Cannot evenly divide thread work among repeat \n");
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
"wrong! Cannot evenly divide work among repeat\n");
constexpr index_t MPerLevel1Cluster = M / MRepeat;
constexpr index_t NPerLevel1Cluster = N / NRepeat;
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
(NPerLevel1Cluster % NLevel1Cluster == 0),
"wrong! Cannot evenly divide work among Level1Cluster\n");
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
(NPerLevel0Cluster % NLevel0Cluster == 0),
"wrong! Cannot evenly divide work among Level0Cluster\n");
static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
(NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
"wrong! thread work size is wrong\n");
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
printf("%u %u, %u %u %u, %u %u\n",
get_block_1d_id(),
get_thread_local_1d_id(),
c_thread_mtx_index.batch,
c_thread_mtx_index.row,
c_thread_mtx_index.col,
mMyThreadOffsetA,
mMyThreadOffsetB);
}
#endif
}
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
{
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{batch_work_id * BatchPerThread,
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
// this should be optimized away because input will be known at compile time
__device__ static MatrixIndex
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{batch_in_c,
m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// loop over k
#pragma unroll
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
{
// loop over batch
#pragma unroll
for(index_t ib = 0; ib < BatchPerThread; ++ib)
{
// read next batch of a, b
if(BlockMatrixStrideA != 0 or ib == 0)
{
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
p_a_block +
a_block_mtx.GetOffsetFromMultiIndex(k_begin,
m_repeat * MPerLevel1Cluster) +
ib * BlockMatrixStrideA + mMyThreadOffsetA,
a_thread_mtx,
p_a_thread +
a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
}
if(BlockMatrixStrideB != 0 or ib == 0)
{
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
p_b_block +
b_block_mtx.GetOffsetFromMultiIndex(k_begin,
n_repeat * NPerLevel1Cluster) +
ib * BlockMatrixStrideB + mMyThreadOffsetB,
b_thread_mtx,
p_b_thread +
b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
}
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
p_a_thread[0],
p_a_thread[1],
p_a_thread[2],
p_a_thread[3],
p_a_thread[4],
p_a_thread[5],
p_a_thread[6],
p_a_thread[7],
p_b_thread[0],
p_b_thread[1],
p_b_thread[2],
p_b_thread[3],
p_b_thread[4],
p_b_thread[5],
p_b_thread[6],
p_b_thread[7]);
}
#endif
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread + ib * ThreadMatrixStrideC);
}
}
}
#if CK_USE_AMD_INLINE_ASM
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t K = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");
static_assert(
BlockMatrixStrideA == 0 && BatchPerThread == 1,
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) +
mMyThreadOffsetB]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) +
mMyThreadOffsetA]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Float4*>(
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Float4*>(
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) +
mMyThreadOffsetB]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) +
mMyThreadOffsetA]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");
static_assert(
BlockMatrixStrideA == 0 && BatchPerThread == 1,
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread);
void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);
constexpr index_t a_lds_row_stride = sizeof(float) * a_block_mtx.RowStride();
constexpr index_t b_lds_row_stride = sizeof(float) * b_block_mtx.RowStride();
constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;
ds_read_b128(reg_a[0], a_lds_loc, 0);
ds_read_b128(reg_b[0], b_lds_loc, 0);
ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
lgkmcnt(1);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
lgkmcnt(0);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
}
#endif
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
FloatC* __restrict__ p_c_block) const
{
constexpr auto c_block_mtx = BlockMatrixC{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t c_thread_offset =
c_thread_mtx_begin.batch * BlockMatrixStrideC +
c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
c_thread_sub_mtx,
p_c_thread +
c_thread_sub_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
n_repeat * NPerLevel1Cluster),
c_block_mtx,
p_c_block +
c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
n_repeat * NPerLevel1Cluster) +
c_thread_offset,
c_thread_sub_mtx.GetLengths());
}
}
}
};
} // namespace
#endif

View File

@@ -0,0 +1,433 @@
#ifndef CK_BLOCKWISE_GEMM_HPP
#define CK_BLOCKWISE_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "threadwise_gemm.hpp"
namespace ck {
// if following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <index_t BlockSize,
class BlockMatrixA,
class BlockMatrixB,
class ThreadMatrixC,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t DataPerReadA,
index_t DataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
struct MatrixIndex
{
index_t row;
index_t col;
};
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
{
constexpr index_t ThreadPerLevel1Cluster =
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t K = BlockMatrixA::NRow();
static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(is_same_type(ThreadMatrixC::GetLengths(), GetThreadMatrixCLengths()),
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
constexpr index_t N = BlockMatrixB::NCol();
constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1Cluster;
index_t level1_n_id = level1_id % NLevel1Cluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0Cluster;
index_t level0_n_id = level0_id % NLevel0Cluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
#if CK_USE_AMD_INLINE_ASM
// TODO: this is not working correctly
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");
using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread);
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
reg_b[1] =
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
reg_a[1] =
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
reg_b[1] = *reinterpret_cast<const Float4*>(
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
reg_a[1] = *reinterpret_cast<const Float4*>(
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
}
#endif
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA* const __restrict__ p_a_block,
const FloatB* const __restrict__ p_b_block,
FloatC* const __restrict__ p_c_thread) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;
#pragma unroll
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// copy A-sub to form A
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
threadwise_matrix_copy(
a_block_mtx,
p_a_block +
a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
mMyThreadOffsetA,
a_thread_mtx,
p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
// copy B-sub to form B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
threadwise_matrix_copy(
b_block_mtx,
p_b_block +
b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
mMyThreadOffsetB,
b_thread_mtx,
p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread,
b_thread_mtx,
False,
p_b_thread,
c_thread_mtx,
False,
p_c_thread);
}
}
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
FloatB* const p_b_block,
FloatC* p_c_thread) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow();
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
// register
FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// preload A, B
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_0 + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_0 + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
bool even_loop = true;
#pragma unroll
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
k_begin += KPerThreadLoop, even_loop = !even_loop)
{ // loop over k
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
// preload next A, B
#pragma unroll
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{ // copy A-sub to form A
threadwise_matrix_copy(a_block_mtx,
p_a_block + mMyThreadOffsetA +
(k_begin + 1) * a_block_mtx.RowStride() +
m_repeat * MPerLevel1Cluster,
a_thread_sub_mtx,
p_a_thread_next + m_repeat * MPerThreadSubC,
a_thread_sub_mtx.GetLengths(),
Number<DataPerReadA>{});
}
#pragma unroll
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{ // copy B-sub to form B
threadwise_matrix_copy(b_block_mtx,
p_b_block + mMyThreadOffsetB +
(k_begin + 1) * b_block_mtx.RowStride() +
n_repeat * NPerLevel1Cluster,
b_thread_sub_mtx,
p_b_thread_next + n_repeat * NPerThreadSubC,
b_thread_sub_mtx.GetLengths(),
Number<DataPerReadB>{});
}
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread_now,
b_thread_mtx,
False,
p_b_thread_now,
c_thread_mtx,
False,
p_c_thread);
}
// last loop
{
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
// C = A * B
threadwise_gemm(a_thread_mtx,
True,
p_a_thread_now,
b_thread_mtx,
False,
p_b_thread_now,
c_thread_mtx,
False,
p_c_thread);
}
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,401 @@
#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
namespace ck {
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
// memory layout (ordering of dimensions) can be different between src and dst
// For now, only support SubLengths[...] == 1 on a merged dimension
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SliceLengths,
class SubLengths,
class DataClusterLengths,
class ThreadClusterArrangeOrder,
class SrcAccessOrder,
class DstAccessOrder,
index_t SrcDataPerRead,
index_t DstDataPerWrite>
struct BlockwiseGenericTensorSliceCopy_v1
{
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
static constexpr index_t nOriginalDimSrc =
SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
static constexpr index_t nOriginalDimDst =
DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
// per-thread offset
index_t mThreadSrcOffset;
index_t mThreadDstOffset;
// "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets, "mThreadDstOriginalMultiId",
// "mThreadDstPartialOffsets" are always calculated inside constructor, and would be
// updated if slicing-window is moved. However, they will not be used if you always move
// the slicing-window along a non-merged dimension. In that case, compiler should be
// able to remove these calculation.
// TODO: make sure compiler would actually remove them in that case
// partial offset in each (merged) dimension
Array<index_t, nDim> mThreadSrcPartialOffsets;
Array<index_t, nDim> mThreadDstPartialOffsets;
// multi-id of original tensor
Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
__device__
BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
Array<index_t, nDim> dst_block_data_multi_id_begin)
{
// check NDim consistency
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
nDim == SubLengths::GetSize() && nDim == DataClusterLengths::GetSize() &&
nDim == ThreadClusterArrangeOrder::GetSize() &&
nDim == SrcAccessOrder::GetSize() && nDim == DstAccessOrder::GetSize(),
"wrong");
// check thread arrange order and read/write access order are valid
static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
is_valid_sequence_map<SrcAccessOrder>::value &&
is_valid_sequence_map<DstAccessOrder>::value,
"wrong!");
// thread cluster
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
DataClusterLengths{}.ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
// BlockSize
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
// divide work
constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
"wrong! cannot evenly divide sliced tensor into sub-tensor");
static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
"wrong! cannot evenly divide sliced tensor into cluster");
});
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
// for now, only support SubLengths.Get() == 1 on a merged dimension that constains
// multiple original dimensions
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
static_assert(SubLengths::Get(IDim) == 1 ||
(!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
!DstDesc::ContainMultipleOriginalDimensions(IDim)),
"wrong! only surpport Sub-Length == 1 on a merged dimension");
});
// calculate mThreadSrcOffset, mThreadDstOffset
const auto thread_cluster_multi_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
const auto data_cluster_multi_id =
reorder_array_given_old2new(thread_cluster_multi_id, ThreadClusterArrangeOrder{});
const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
// original multi-id
mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
src_block_data_multi_id_begin + thread_data_multi_id_begin);
mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
dst_block_data_multi_id_begin + thread_data_multi_id_begin);
// partial offset on each dimension
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
constexpr index_t idim = IDim.Get();
constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
mThreadSrcPartialOffsets(idim) = src_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
});
static_for<0, nDim, 1>{}([&](auto IDim_) {
constexpr auto IDim = decltype(IDim_){};
constexpr index_t idim = IDim.Get();
constexpr auto dst_partial_original_dims =
DstDesc::GetContainedOriginalDimensions(IDim);
constexpr auto dst_partial_original_desc =
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
mThreadDstPartialOffsets(idim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
});
// complete offset
mThreadSrcOffset = accumulate_on_array(
mThreadSrcPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
mThreadDstOffset = accumulate_on_array(
mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
#if 0
if(get_block_1d_id() == 0)
{
printf("id %5u %5u: "
"src_block_data_multi_id_begin: %u %u %u %u, "
"thread_cluster_multi_id: %u %u %u %u, "
"data_cluster_multi_id: %u %u %u %u, "
"thread_data_multi_id_begin: %u %u %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
src_block_data_multi_id_begin[0],
src_block_data_multi_id_begin[1],
src_block_data_multi_id_begin[2],
src_block_data_multi_id_begin[3],
thread_cluster_multi_id[0],
thread_cluster_multi_id[1],
thread_cluster_multi_id[2],
thread_cluster_multi_id[3],
data_cluster_multi_id[0],
data_cluster_multi_id[1],
data_cluster_multi_id[2],
data_cluster_multi_id[3],
thread_data_multi_id_begin[0],
thread_data_multi_id_begin[1],
thread_data_multi_id_begin[2],
thread_data_multi_id_begin[3],
mThreadSrcOffset,
mThreadDstOffset);
}
#endif
}
__device__ static constexpr index_t GetRegisterClipboardSize()
{
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
return thread_tensor_desc.GetElementSpace();
}
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
constexpr auto thread_sub_tensor_lengths = SubLengths{};
constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{};
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if 0
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
const index_t src_offset =
SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
const index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#else // HIP compiler performs better with these codes
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto src_thread_data_multi_id_begin =
repeat_multi_id * data_per_cluster_per_dims;
constexpr auto clipboard_data_multi_id_begin =
repeat_multi_id * thread_sub_tensor_lengths;
constexpr index_t src_offset =
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
#endif
threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
p_src + src_offset + mThreadSrcOffset,
make_zero_array<index_t, nDim>(),
thread_tensor_desc,
p_clipboard + clipboard_offset,
make_zero_array<index_t, nDim>(),
thread_sub_tensor_lengths,
SrcAccessOrder{},
Number<SrcDataPerRead>{});
});
}
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto thread_sub_tensor_lengths = SubLengths{};
constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{};
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
#if 0
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
const index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#else // HIP compiler performs better with these codes
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto clipboard_data_multi_id_begin =
repeat_multi_id * thread_sub_tensor_lengths;
constexpr auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
constexpr index_t dst_offset =
DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
#endif
threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
p_clipboard + clipboard_offset,
make_zero_array<index_t, nDim>(),
DstDesc{},
p_dst + dst_offset + mThreadDstOffset,
make_zero_array<index_t, nDim>(),
thread_sub_tensor_lengths,
DstAccessOrder{},
Number<DstDataPerWrite>{});
});
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
Float p_clipboard[GetRegisterClipboardSize()];
RunLoadRegisterClipboard(p_src, p_clipboard);
RunStoreRegisterClipboard(p_clipboard, p_dst);
}
// When moving the slicing windows along a merged dimension, if the strides of the
// contained (by the merged dimension) original dimensions are in descending order,
// then there is no guarantee that the new offset will be larger than the old offset
// for movement in positive direction (vice versue for movement in negative direction).
// As a result, there is the possiblity that the offset calculation may result in
// unsigned integer underflow (due to "-" operation). However, this hazard should not
// happen, as long as the users make sure the slicing window would not be moved out of
// the boundary of the tensor being sliced. This functions doesn't do runtime sanity
// check on out-of-bound slicing window, for performance reason
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
__device__ void MoveSlicingWindowOnSourceTensor(
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
{
constexpr auto IDim = Number<IDim_>{};
constexpr index_t idim = IDim.Get();
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto fwd) {
// logic for a merged dimension, also works for non-merged dimension, but its logic may
// be unncessarily complicated for compiler to remove calculations that are useless for
// a non-merged dimension
// extract partial original dimensions
constexpr auto src_partial_original_dims =
SrcDesc::GetContainedOriginalDimensions(IDim);
constexpr auto src_partial_original_desc =
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
// calculate new partial original multi-id
auto old_src_partial_original_multi_id =
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
auto new_src_partial_original_multi_id =
src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
old_src_partial_original_multi_id, StepSize, direction);
// update "mThreadSrcOriginalMultiId"
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
constexpr auto I = decltype(I_){};
constexpr index_t idim_original = src_partial_original_dims.Get(I);
mThreadSrcOriginalMultiId(idim_original) =
new_src_partial_original_multi_id[I.Get()];
});
// calculate new partial offset on this merged dimension
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
const index_t new_src_partial_offset =
src_partial_original_desc.GetOffsetFromMultiIndex(
new_src_partial_original_multi_id);
// update "mThreadSrcPartialOffsets"
mThreadSrcPartialOffsets(idim) = new_src_partial_offset;
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
}).Else([&](auto fwd) {
// Logic for non-merged dimension. If you are never going to move the slicing window on
// a merged dimension, then "mThreadSrcOriginalMultiId" and "mThreadSrcPartialOffsets",
// which are being calculated here, will never be used later. In this case, compiler
// should be able to remove these calculations.
// TODO: make sure compiler would actually remove them in this case.
// It is the user's responsiblity to make sure the slicing window will not be moved out
// of the boundary of the tensor being sliced. Otherwise, there might be hazard like
// unsigned integer underflow. That is NO runtime sanity check to prevent the hazard
constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
static_if<PositiveDirection>{}([&](auto fwd) {
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(idim_original) += StepSize;
mThreadSrcPartialOffsets(idim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
}).Else([&](auto fwd) {
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
mThreadSrcOriginalMultiId(idim_original) -= StepSize;
mThreadSrcPartialOffsets(idim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
});
});
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,299 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"
namespace ck {
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcLengths,
class SrcSubLengths,
class SrcClusterLengths,
class MapDst2Src,
class MapThreadCluster2SrcCluster,
index_t SrcDataPerRead,
index_t DstDataPerWrite>
struct BlockwiseTensorSliceReorderCopy_v3
{
static constexpr index_t nDim = SrcLengths::GetSize();
index_t mThreadSrcOffset;
index_t mThreadDstOffset;
__device__
BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
Array<index_t, nDim> dst_block_data_multi_id_begin)
{
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto src_lengths = SrcLengths{};
constexpr auto map_dst2src = MapDst2Src{};
constexpr auto src_sub_lengths = SrcSubLengths{};
constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);
constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};
constexpr auto src_cluster_lengths = SrcClusterLengths{};
constexpr auto thread_cluster_lengths =
src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);
constexpr auto thread_cluster_desc =
make_ConstantTensorDescriptor_packed(thread_cluster_lengths);
// sanity check: data type
static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");
// sanity check: nDim
static_assert(SrcDesc::GetNumOfDimension() == nDim &&
DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
SrcSubLengths::GetSize() == nDim &&
SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
MapThreadCluster2SrcCluster::GetSize() == nDim,
"wrong! nDim is not consistent\n");
// sanity check: BlockSize
constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();
static_assert(BlockSize >= num_active_thread,
"wrong! BlockSize is not big enough for ThreadPerDims!");
// sanity check: work division
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto I = decltype(IDim){};
constexpr index_t src_len = src_lengths.Get(I);
constexpr index_t src_sub_len = src_sub_lengths.Get(I);
constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
"wrong! cannot evenly divide Src tensor lengths");
});
// sanity check: src read
static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
"wrong! only support SrcDataPerRead == 1, 2 or 4!\n");
static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
"wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");
static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
"wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");
static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
"wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
"keep alignment");
// sanity check: dst write
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
"wrong! only support DstDataPerWrite == 1, 2 or 4!\n");
static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
"wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");
static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
"wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");
static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
"wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
"keep alignment");
// start dividing work
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
const auto thread_multi_id =
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
// compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
// regsiters, or only one copy???
auto src_data_multi_id =
reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);
static_for<0, nDim, 1>{}([&](auto IDim) {
constexpr auto I = decltype(IDim){};
constexpr index_t i = I.Get();
// compiler: will it really compute index here, or be merged with
// GetOffsetFromMultiIndex and
// optimized away???
src_data_multi_id(i) *= src_sub_lengths.Get(I);
});
// compiler: will it really compute index here, or be merged with GetOffsetFromMultiIndex
// and
// optimized away???
const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);
mThreadSrcOffset =
src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);
mThreadDstOffset =
dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
#if 0
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
}
if(get_block_1d_id() == 0)
{
printf("id %5u %5u: "
"thread_multi_id: %u %u, "
"src_block_data_multi_id_begin: %u %u, "
"src_data_multi_id: %u %u, "
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
get_block_1d_id(),
get_thread_local_1d_id(),
thread_multi_id[0],
thread_multi_id[1],
src_block_data_multi_id_begin[0],
src_block_data_multi_id_begin[1],
src_data_multi_id[0],
src_data_multi_id[1],
mThreadSrcOffset,
mThreadDstOffset);
}
#endif
}
__device__ static constexpr index_t GetRegisterClipboardSize()
{
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
constexpr auto repeat_lengths = transform_sequences(
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
return thread_tensor_desc.GetElementSpace();
}
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
constexpr auto repeat_lengths = transform_sequences(
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
constexpr index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(src_data_multi_id);
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
threadwise_tensor_slice_copy(SrcDesc{},
p_src + src_offset + mThreadSrcOffset,
thread_tensor_desc,
p_clipboard + clipboard_offset,
thread_sub_tensor_lengths,
Number<SrcDataPerRead>{});
});
}
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
constexpr auto src_data_per_cluster_per_dims =
thread_sub_tensor_lengths * SrcClusterLengths{};
constexpr auto repeat_lengths = transform_sequences(
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
constexpr auto thread_tensor_desc =
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
// reorder src_data_multi_id to get dst_data_multi_id
constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});
constexpr index_t clipboard_offset =
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
constexpr index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id);
// write in the order of dst
#if 1
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
p_clipboard + clipboard_offset,
DstDesc{},
p_dst + dst_offset +
mThreadDstOffset,
thread_sub_tensor_lengths,
MapDst2Src{});
#else
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
p_clipboard + clipboard_offset,
DstDesc{},
p_dst + dst_offset +
mThreadDstOffset,
thread_sub_tensor_lengths,
MapDst2Src{},
Number<DstDataPerWrite>{});
#endif
});
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
Float p_clipboard[GetRegisterClipboardSize()];
RunLoadRegisterClipboard(p_src, p_clipboard);
RunStoreRegisterClipboard(p_clipboard, p_dst);
}
// this function doesn't do santiy check on whether the slicing window is out of the boundary
// of the tensor being sliced
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
__device__ void MoveSlicingWindowOnSourceTensor(
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
{
constexpr auto IDim = Number<IDim_>{};
static_if<PositiveDirection>{}([&](auto fwd) {
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
}).Else([&](auto fwd) { mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim); });
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,60 @@
#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
#define CK_THREADWISE_4D_TENSOR_OP_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace ck {
template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto desc = Desc{};
#if 0
if(get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
}
#endif
constexpr index_t nshift = NShift::mValue;
constexpr index_t did0_end =
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
constexpr index_t did1_end =
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
constexpr index_t did2_end =
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
constexpr index_t did3_end =
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
for(index_t did0 = 0; did0 < did0_end; ++did0)
{
for(index_t did1 = 0; did1 < did1_end; ++did1)
{
for(index_t did2 = 0; did2 < did2_end; ++did2)
{
for(index_t did3 = 0; did3 < did3_end; ++did3)
{
const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
p[dindex] = p[sindex];
}
}
}
}
}
} // namespace ck
#endif

View File

@@ -0,0 +1,228 @@
#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "threadwise_tensor_slice_copy.hpp"
namespace ck {
// optimized for scenario if p_in, p_wei, p_out are in register
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_1(InDesc,
TInWei* const __restrict__ p_in,
WeiDesc,
TInWei* const __restrict__ p_wei,
OutDesc,
TOut* __restrict__ p_out)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
#if 0
if(blockIdx.x == 0 && get_thread_local_1d_id() == 0)
{
print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: ");
print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: ");
print_ConstantTensorDescriptor(out_desc, "threadwise_direct_convolution: out_desc: ");
}
#endif
for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
{
for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
{
for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
{
for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
{
for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
{
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
const index_t hi = ho + y;
const index_t wi = wo + x;
const index_t in_index =
in_desc.GetOffsetFromMultiIndex(n, c, hi, wi);
const index_t wei_index =
wei_desc.GetOffsetFromMultiIndex(k, c, y, x);
const index_t out_index =
out_desc.GetOffsetFromMultiIndex(n, k, ho, wo);
fused_multiply_accumulate(
p_out[out_index], p_wei[wei_index], p_in[in_index]);
}
}
}
}
}
}
}
}
// Optimized for scenario if p_in and p_wei are in LDS, p_out are in register
// Copy in and wei into register before doing convolution
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_2(InDesc,
TInWei* const __restrict__ p_in,
WeiDesc,
TInWei* const __restrict__ p_wei,
OutDesc,
TOut* __restrict__ p_out)
{
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr auto in_reg_desc = make_ConstantTensorDescriptor_packed(in_desc.GetLengths());
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor_packed(wei_desc.GetLengths());
// register
TInWei p_in_reg[in_reg_desc.GetElementSpace()];
TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];
// copy input tensor into register
threadwise_tensor_slice_copy(
in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{});
// copy input tensor into register
threadwise_tensor_slice_copy(
wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{});
// do convolution
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}
// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load 1x1 weight into register, and do 1x1 convolution in register.
template <class Data, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution_3(InDesc,
Data* const __restrict__ p_in,
WeiDesc,
Data* const __restrict__ p_wei,
OutDesc,
Data* __restrict__ p_out)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_desc = InDesc{};
constexpr auto wei_desc = WeiDesc{};
constexpr auto out_desc = OutDesc{};
constexpr auto in_reg_desc = make_ConstantTensorDescriptor(Sequence<in_desc.GetLength(I0),
in_desc.GetLength(I1),
out_desc.GetLength(I2),
out_desc.GetLength(I3)>{});
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});
Data p_in_reg[in_reg_desc.GetElementSpace()];
Data p_wei_reg[wei_reg_desc.GetElementSpace()];
constexpr index_t in_w_new_read = 1;
constexpr auto in_desc_reg_new_read =
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
in_reg_desc.GetLength(I1),
in_reg_desc.GetLength(I2),
in_w_new_read>{});
#if 0
// this verison reused old input data in register, and read new data from LDS
// loop over vertical direction
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// read first input
threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
in_reg_desc,
p_in_reg,
in_reg_desc.GetLengths());
// read first 1x1 weight
threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
wei_reg_desc,
p_wei_reg,
wei_reg_desc.GetLengths());
// do first 1x1 conv
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
// loop over horizontal direction
for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
wei_reg_desc,
p_wei_reg,
wei_reg_desc.GetLengths());
// shift old input to the left
threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});
// read new input
threadwise_4d_tensor_copy(
in_desc,
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
in_reg_desc,
p_in_reg +
in_reg_desc.GetOffsetFromMultiIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
in_desc_reg_new_read.GetLengths());
// do 1x1 conv
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}
}
#elif 1
// this version read all input from LDS when filter moves
// loop over vertical direction
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
{
// loop over horizontal direction
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
{
// read new weight
threadwise_4d_tensor_copy(wei_desc,
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
wei_reg_desc,
p_wei_reg,
wei_reg_desc.GetLengths());
// read new input
threadwise_4d_tensor_copy(in_desc,
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x),
in_reg_desc,
p_in_reg,
in_reg_desc.GetLengths());
// do 1x1 conv
threadwise_direct_convolution_1(
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
}
}
#endif
}
} // namespace ck
#endif

View File

@@ -0,0 +1,123 @@
#ifndef CK_THREADWISE_GEMM_HPP
#define CK_THREADWISE_GEMM_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
namespace ck {
template <class Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
for(index_t i = 0; i < Matrix::NRow(); ++i)
{
for(index_t j = 0; j < Matrix::NCol(); ++j)
{
const index_t id = Matrix::GetOffsetFromMultiIndex(i, j);
p_thread[id] = Float(0);
}
}
}
template <class Float,
class SrcMatrix,
class DstMatrix,
index_t NRow,
index_t NCol,
index_t DataPerRead>
__device__ void threadwise_matrix_copy(SrcMatrix,
const Float* __restrict__ p_src,
DstMatrix,
Float* __restrict__ p_dst,
Sequence<NRow, NCol>,
Number<DataPerRead>)
{
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
constexpr auto src_mtx = SrcMatrix{};
constexpr auto dst_mtx = DstMatrix{};
for(index_t i = 0; i < NRow; ++i)
{
for(index_t j = 0; j < NCol; j += DataPerRead)
{
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
}
}
}
template <class MatrixA,
class MatrixB,
class MatrixC,
bool TransA,
bool TransB,
bool TransC,
class FloatA,
class FloatB,
class FloatC>
__device__ void threadwise_gemm(MatrixA,
integral_constant<bool, TransA>,
const FloatA* __restrict__ p_a_thread,
MatrixB,
integral_constant<bool, TransB>,
const FloatB* __restrict__ p_b_thread,
MatrixC,
integral_constant<bool, TransC>,
FloatC* __restrict__ p_c_thread)
{
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("p_a_thread: %f %f %f %f\n",
p_a_thread[0],
p_a_thread[1],
p_a_thread[2],
p_a_thread[3]);
printf("p_b_thread: %f %f %f %f\n",
p_b_thread[0],
p_b_thread[1],
p_b_thread[2],
p_b_thread[3]);
}
#endif
if(TransA && (!TransB) && (!TransC))
{
constexpr auto a_mtx = MatrixA{};
constexpr auto b_mtx = MatrixB{};
constexpr auto c_mtx = MatrixC{};
constexpr index_t M = c_mtx.NRow();
constexpr index_t N = c_mtx.NCol();
constexpr index_t K = a_mtx.NRow(); // A is transposed
for(index_t k = 0; k < K; ++k)
{
for(index_t i = 0; i < M; ++i)
{
for(index_t j = 0; j < N; ++j)
{
const index_t aindex = a_mtx.GetOffsetFromMultiIndex(k, i); // A is transposed
const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
}
}
}
}
else
{
// not implemented
assert(false);
}
}
} // namespace ck
#endif

View File

@@ -0,0 +1,20 @@
#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP
#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
namespace ck {
template <class Float, class TDesc>
__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
{
static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
p[offset] = static_cast<Float>(0);
});
}
} // namespace ck
#endif

View File

@@ -0,0 +1,107 @@
#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
namespace ck {
template <class Float,
class SrcDesc,
class DstDesc,
class SliceLengths,
class DimAccessOrder,
index_t DataPerAccess>
__device__ void threadwise_generic_tensor_slice_copy_v1(
SrcDesc,
const Float* __restrict__ p_src,
Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
DstDesc,
Float* __restrict__ p_dst,
Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
SliceLengths,
DimAccessOrder,
Number<DataPerAccess>)
{
constexpr index_t nDim = SrcDesc::GetNumOfDimension();
static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
nDim == SliceLengths::GetSize() && nDim == DimAccessOrder::GetSize(),
"wrong! # of dimensions not the same");
static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
#if 0
// doesn't compile, because merged-tensor reordering is not implemented
// TODO: implement tensor desc ops for merged-tensor
constexpr auto src_strides_in_access_order =
SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
constexpr auto dst_strides_in_access_order =
SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
// check src/dst stride on the lowest access dimension
static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
(DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
"wrong! src/dst stride on the lowest access dimension needs to be 1 for "
"vectorized read/write");
#endif
constexpr auto slice_lengths_in_access_order =
SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});
// check slice length on the lowest access dimension
static_assert(slice_lengths_in_access_order.Back() % DataPerAccess == 0,
"wrong! slice length on the lowest access dimension should be evenly divided by "
"DataPerAccess");
constexpr index_t num_access_on_lowest_access_dimension =
slice_lengths_in_access_order.Back() / DataPerAccess;
constexpr auto access_lengths = slice_lengths_in_access_order.Modify(
Number<nDim - 1>{}, Number<num_access_on_lowest_access_dimension>{});
using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;
#if 1
ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
auto data_multi_id_in_access_order = access_multi_id;
data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;
const auto data_multi_id =
reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
const index_t src_index =
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
const index_t dst_index =
DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
});
#else
static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
constexpr auto data_multi_id_in_access_order =
access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
constexpr auto data_multi_id = reorder_array_given_old2new(
sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
const index_t src_index =
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
const index_t dst_index =
DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
});
#endif
}
} // namespace ck
#endif

View File

@@ -0,0 +1,202 @@
#ifndef CK_THREADWISE_TENSOR_SLICE_COPY_HPP
#define CK_THREADWISE_TENSOR_SLICE_COPY_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
namespace ck {
// need to assume src and dst is aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_tensor_slice_copy(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
SrcOpLengths,
Number<DataPerRead>)
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
constexpr index_t nDim = SrcOpLengths::GetSize();
static_assert(SrcDesc{}.GetNumOfDimension() == nDim && DstDesc{}.GetNumOfDimension() == nDim,
"wrong! dimension not consistent");
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(src_desc, "src_desc");
print_ConstantTensorDescriptor(dst_desc, "dst_desc");
print_ConstantTensorDescriptor(ref_desc, "ref_desc");
}
#endif
static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number<nDim - 1>{}) == 1 &&
DstDesc{}.GetStride(Number<nDim - 1>{}) == 1),
"wrong! only support stride[nDim-1] == 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(
SrcDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0 &&
DstDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0,
"wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment");
constexpr index_t L_Back = SrcOpLengths{}.Back();
static_assert(L_Back % DataPerRead == 0,
"wrong! lengths[nDim-1] should be evenly divided by DataPerRead");
constexpr index_t nRead = L_Back / DataPerRead;
static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
static_for<0, nRead, 1>{}([&](auto IRead) {
constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});
const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(multi_id);
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
*(reinterpret_cast<const vector_t*>(&p_src[src_index]));
});
});
}
// access in order of src
template <class SrcData,
class DstData,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src>
__device__ void
threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc,
const SrcData* __restrict__ p_src,
DstDesc,
DstData* __restrict__ p_dst,
SrcOpLengths,
MapDst2Src)
{
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
ford<SrcOpLengths>{}([&](auto src_multi_id) {
const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{});
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
p_dst[dst_index] = p_src[src_index];
});
}
// access in order of dst
template <class SrcData,
class DstData,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src>
__device__ void
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc,
const SrcData* __restrict__ p_src,
DstDesc,
DstData* __restrict__ p_dst,
SrcOpLengths,
MapDst2Src)
{
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});
ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) {
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
p_dst[dst_index] = p_src[src_index];
});
}
// access in order of dst
// manually pack data into vector before write
template <class Float,
class SrcDesc,
class DstDesc,
class SrcOpLengths,
class MapDst2Src,
index_t DstDataPerWrite>
__device__ void
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
const Float* __restrict__ p_src,
DstDesc,
Float* __restrict__ p_dst,
SrcOpLengths,
MapDst2Src,
Number<DstDataPerWrite>)
{
using vector_t = typename vector_type<Float, DstDataPerWrite>::MemoryType;
constexpr index_t nDim = SrcOpLengths::GetSize();
static_assert(DstDataPerWrite == 1 || DstDesc{}.GetStride(Number<nDim - 1>{}) == 1,
"wrong! only support dst.stride[nDim-1] == 1, if DstDataPerWrite != 1");
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
"wrong! only support DstDataPerWrite == 1, 2 or 4");
static_assert(
DstDesc{}.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
"wrong! dst.stride[nDim-2] should be multiple of DstDataPerWrite to keep alignment");
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});
constexpr index_t L_Dst_Back = dst_op_lengths.Back();
static_assert(L_Dst_Back % DstDataPerWrite == 0,
"wrong! dst.lengths[nDim-1] should be evenly divided by DstDataPerWrite");
constexpr index_t nWrite = L_Dst_Back / DstDataPerWrite;
ford<decltype(dst_op_lengths.PopBack())>{}([&](auto ids) {
static_for<0, nWrite, 1>{}([&](auto IWrite) {
vector_t dst_vec_data;
// pack data
static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
const auto dst_multi_id =
ids.PushBack(IWrite.Get() * DstDataPerWrite + IDstData.Get());
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
vector_type<Float, DstDataPerWrite>::SetScalar(
dst_vec_data, p_src[src_index], IDstData);
});
// write data
const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite);
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data;
});
});
}
} // namespace ck
#endif