Reorganize files, Part 1 (#119)

* delete obselete files

* move files

* build

* update cmake

* update cmake

* fix build

* reorg examples

* update cmake for example and test
This commit is contained in:
Chao Liu
2022-03-08 21:46:36 -06:00
committed by GitHub
parent 245f741457
commit 5d37d7bff4
422 changed files with 388 additions and 3326 deletions

View File

@@ -0,0 +1,394 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. AKMBlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BKNBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <
index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AKMBlockDesc,
typename BKNBlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);
static constexpr index_t M100 = M1N1ThreadClusterM100;
static constexpr index_t N100 = M1N1ThreadClusterN100;
static constexpr index_t M101 = M1N1ThreadClusterM101;
static constexpr index_t N101 = M1N1ThreadClusterN101;
static constexpr index_t M11 = M1PerThreadM11;
static constexpr index_t N11 = N1PerThreadN11;
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
static constexpr index_t M0 = M / M1;
static constexpr index_t N0 = N / N1;
__host__ __device__ static constexpr auto
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
{
const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
AKMBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return a_k_m0_m1_block_desc;
}
__host__ __device__ static constexpr auto
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
{
const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
BKNBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return b_k_n0_n1_block_desc;
}
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_unmerge_transform(make_tuple(
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
}
__host__ __device__ static constexpr auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<M0>{}),
make_unmerge_transform(
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_pass_through_transform(Number<N0>{}),
make_unmerge_transform(
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
}
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
{
return Sequence<M0, M11, N0, N11>{};
}
static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});
public:
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
{
static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M101 * M100 * N101 * N100,
"wrong! blocksize and cluster size not consistent");
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2, "wrong");
}
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
make_pass_through_transform(M0),
make_pass_through_transform(M11),
make_pass_through_transform(N0),
make_pass_through_transform(N11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
}
__host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }
__host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; }
template <typename CM0M1N0N1ThreadDesc,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
a_k_m0_m1_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
b_k_n0_n1_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_k_m0_m1_thread_desc_),
decltype(b_k_n0_n1_thread_desc_),
CM0M1N0N1ThreadDesc,
Sequence<KPerThread>,
Sequence<1, M1PerThreadM11>,
Sequence<1, N1PerThreadN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(I0, I0, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(I0, I0, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(I0, I1, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(I0, I1, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThread, K, KPerThread>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(k, I0, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(k, I0, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(k, I1, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(k, I1, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
// A[K, M0, M1]
static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
// B[K, N0, N1]
static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_k_m0_m1_block_desc_),
decltype(a_k_m0_m1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M11,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_k_n0_n1_block_desc_),
decltype(b_k_n0_n1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N11,
1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,410 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace ck {
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_BK0_BM_BK1,
typename BBlockDesc_BK0_BN_BK1,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t BM1 = BM100 * BM101 * BM11;
static constexpr index_t BN1 = BN100 * BN101 * BN11;
static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1;
__host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_block_bk0_bm0_bm1_bk1;
}
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_block_desc_bk0_bn0_bn1_bk1;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_unmerge_transform(make_tuple(
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<BM0>{}),
make_unmerge_transform(
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_pass_through_transform(Number<BN0>{}),
make_unmerge_transform(
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
}
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<BM0, BM11, BN0, BN11>{};
}
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
{
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
"wrong! blocksize and cluster size not consistent");
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO remove this restriction
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
BM10BN10ThreadClusterBN10Xs::Size() == 2,
"wrong!");
// TODO: remove this restriction
static_assert(BM0 == 2 && BN0 == 2, "wrong");
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr auto adaptor0 =
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
make_pass_through_transform(BM0),
make_pass_through_transform(BM11),
make_pass_through_transform(BN0),
make_pass_through_transform(BN11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
}
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(BM0 == 2 && BN0 == 2 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction =
ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
CThreadDesc_BM0_BM11_BN0_BN11,
Sequence<BK0PerThread, BK1>,
Sequence<1, BM1PerThreadBM11>,
Sequence<1, BN1PerThreadBN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of bk0
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
// read A_sub_0
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I1, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I1, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,175 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp"
#include "threadwise_gemm_dlops_v3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_E1_K1_E2,
typename BBlockDesc_E1_N_Ho_Wo_E2,
typename CThreadDesc_K_N_Ho_Wo,
index_t EPerThreadLoop,
index_t KPerThreadLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0);
static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2);
static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
static constexpr auto b_thread_mtx_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
Number<1>{},
Number<HoPerThread>{},
Number<WoPerThread>{},
Number<E2>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
{
static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(
ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
"wrong! E dimension not consistent\n");
static_assert(E1 % EPerThreadLoop == 0, "");
static_assert(KPerThread % KPerThreadLoop == 0, "");
static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
WoPerBlock % WoPerThread == 0,
"wrong! Cannot evenly divide work among\n");
constexpr auto KThreadCluster = KPerBlock / KPerThread;
constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
constexpr auto WThreadCluster = WoPerBlock / WoPerThread;
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n");
}
__device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
{
return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
}
__device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
{
constexpr auto K0 = KPerBlock / KPerThread;
constexpr auto N0 = I1;
constexpr auto H0 = HoPerBlock / HoPerThread;
constexpr auto W0 = WoPerBlock / WoPerThread;
constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
make_tuple(Sequence<0, 1, 2, 3>{}),
make_tuple(Sequence<0>{}));
const auto c_k_n_h_w_thread_cluster_idx =
c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
return c_k_n_h_w_thread_cluster_idx;
}
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(
is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
"wrong! inconsistent type");
constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB,
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_)>{};
static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin, I0),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0, I0),
a_thread_buf);
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, I0, I0, I0),
c_thread_buf,
make_tuple(k_begin, I0, I0, I0));
});
});
}
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
}
private:
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc_E1_K1_E2,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
Sequence<0, 1, 2>,
2,
E2,
E2>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,340 @@
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
StaticBufferTupleOfVector<AddressSpaceEnum_t::Vgpr,
FloatAcc,
MRepeat * NRepeat,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_;
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
{
const index_t thread_id = get_thread_local_1d_id();
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], Number<KPack>{} * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], Number<KPack>{} * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_block_desc_g_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
}
template <typename CGridDesc_G_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
{
const auto G = c_grid_desc_g_m_n.GetLength(I0);
const auto M = c_grid_desc_g_m_n.GetLength(I1);
const auto N = c_grid_desc_g_m_n.GetLength(I2);
const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
c_grid_desc_g_m_n,
make_tuple(make_pass_through_transform(G),
make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
});
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.template Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
private:
// A[M0, M1, M2, KPerBlock]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));
// B[N0, N1, N2, KPerBlock]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerBlock>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerBlock>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
} // namespace ck
#endif

View File

@@ -0,0 +1,172 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v3r1.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename SrcElementwiseOperation,
typename DstElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const DstElementwiseOperation& dst_element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
src_element_op,
dst_desc,
make_zero_multi_index<nDim>(),
dst_element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
SrcElementwiseOperation,
DstElementwiseOperation,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,156 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v5r1.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
typename SrcVectorTensorLengths,
typename DstVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v5r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
}
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
// SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_step_hack);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v5r1<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorTensorLengths,
DstVectorTensorLengths,
SrcVectorTensorContiguousDimOrder,
DstVectorTensorContiguousDimOrder,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,133 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r1.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r1<SrcData,
DstData,
SrcDesc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,157 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r2.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. It does not keep reference to tensor descriptor
// 3. Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
typename Src1Data,
typename DstData,
typename Src0Desc,
typename Src1Desc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r2
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrc0SliceOrigin(
src0_desc, src0_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc1SliceOrigin(
src1_desc, src1_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc& src0_desc,
const Src0Buffer& src0_buf,
const Src1Desc& src1_desc,
const Src1Buffer& src1_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
}
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
Src1Data,
DstData,
Src0Desc,
Src1Desc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrc0ResetCoordinateAfterRun,
ThreadTransferSrc1ResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,182 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r3.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
typename ElementwiseOperation,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename Src0Data,
typename Src1Data,
typename Src2Data,
typename DstData,
typename Src0Desc,
typename Src1Desc,
typename Src2Desc,
typename DstDesc,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool ThreadTransferSrc0ResetCoordinateAfterRun,
bool ThreadTransferSrc1ResetCoordinateAfterRun,
bool ThreadTransferSrc2ResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r3
{
static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_block_slice_origin,
const Src1Desc& src1_desc,
const Index& src1_block_slice_origin,
const Src2Desc& src2_desc,
const Index& src2_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src0_desc,
make_zero_multi_index<nDim>(),
src1_desc,
make_zero_multi_index<nDim>(),
src2_desc,
make_zero_multi_index<nDim>(),
dst_desc,
make_zero_multi_index<nDim>(),
element_op)
{
static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == DimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
threadwise_transfer_.SetSrc0SliceOrigin(
src0_desc, src0_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc1SliceOrigin(
src1_desc, src1_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetSrc2SliceOrigin(
src2_desc, src2_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
__device__ void Run(const Src0Desc& src0_desc,
const Src0Buffer& src0_buf,
const Src1Desc& src1_desc,
const Src1Buffer& src1_buf,
const Src2Desc& src2_desc,
const Src2Buffer& src2_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.Run(
src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
}
}
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
}
}
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
}
}
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
Src1Data,
Src2Data,
DstData,
Src0Desc,
Src1Desc,
Src2Desc,
DstDesc,
ElementwiseOperation,
decltype(thread_slice_lengths),
DimAccessOrder,
VectorDim,
ScalarPerVector,
DstInMemOp,
ThreadTransferSrc0ResetCoordinateAfterRun,
ThreadTransferSrc1ResetCoordinateAfterRun,
ThreadTransferSrc2ResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,185 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"
namespace ck {
template <typename Buffer1dDescType,
typename AccDataType,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
bool ReorderThreadClusters,
typename OpReduce,
bool PropagateNan>
struct PartitionedBlockwiseReductionOn1dBuffer
{
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
"The buffer size should be the same as BlockSize!");
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename BufferType>
__device__ static void Reduce(BufferType& block_buffer,
AccDataType& accuData,
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
{
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());
if(thread_k_cluster_id < indOffset)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t offset1 =
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
Accumulation::Calculate(opData1, opData2);
block_buffer(offset1) = type_convert<AccDataType>(opData1);
}
__syncthreads();
});
index_t offset = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
accuData = type_convert<AccDataType>(block_buffer[offset]);
};
};
template <typename Buffer1dDescType,
typename AccDataType,
typename IndexDataType,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
bool ReorderThreadClusters,
typename OpReduce,
bool PropagateNan>
struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
{
static constexpr auto buffer_1d_desc = Buffer1dDescType{};
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
"The product of cluster lengths should be same as BlockSize!");
static_assert(KThreadClusterSize > 1, "Parallel reduction need work on at least two elements");
static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
"The buffer size should be the same as BlockSize!");
using Accumulation =
detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;
// This interface accumulates on both data values and indices
template <typename BufferType, typename IdxBufferType>
__device__ static void Reduce(BufferType& block_val_buffer,
IdxBufferType& block_idx_buffer,
AccDataType& accuData,
IndexDataType& accuIndex,
index_t thread_m_cluster_id,
index_t thread_k_cluster_id)
{
constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();
static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << I();
if(thread_k_cluster_id % (indOffset * 2) == 0)
{
// consider the thread clusters order, ensure the contiguous locations are accessed
// by contiguous Thread-ID
index_t offset1 =
ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(make_tuple(
thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
index_t offset2 = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(
(thread_k_cluster_id + indOffset) * MThreadClusterSize +
thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize +
(thread_k_cluster_id + indOffset)));
AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
IndexDataType currIndex1 = block_idx_buffer[offset1];
IndexDataType currIndex2 = block_idx_buffer[offset2];
Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
block_val_buffer(offset1) = type_convert<AccDataType>(opData1);
block_idx_buffer(offset1) = currIndex1;
}
__syncthreads();
});
index_t offset = ReorderThreadClusters
? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
: buffer_1d_desc.CalculateOffset(
make_tuple(thread_m_cluster_id * KThreadClusterSize));
accuData = type_convert<AccDataType>(block_val_buffer[offset]);
accuIndex = block_idx_buffer[offset];
}
};
}; // end of namespace ck
#endif