Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-03 05:01:25 +00:00
Reorganize files, Part 1 (#119)
* delete obsolete files
* move files
* build
* update cmake
* update cmake
* fix build
* reorg examples
* update cmake for example and test
@@ -0,0 +1,394 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"

namespace ck {

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block; C is distributed among the threads
// Assume:
//   1. A:
//      1. AKMBlockDesc is known at compile-time
//      2. ABlockBuffer is DynamicBuffer
//   2. B:
//      1. BKNBlockDesc is known at compile-time
//      2. BBlockBuffer is DynamicBuffer
//   3. C:
//      1. CM0M1N0N1ThreadDesc is known at compile-time
//      2. CThreadBuffer is StaticBuffer
// Also assume:
//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
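// Pipeline sketch of Run() below (M0 = N0 = 2):
//   prologue: read A_sub_0, B_sub_0, B_sub_1, A_sub_1; fma C_sub_00, C_sub_01
//   k-loop:   read next A_sub_0 / fma C_sub_10; read next B_sub_0 / fma C_sub_11;
//             read next B_sub_1, A_sub_1; fma C_sub_00, C_sub_01 of the next k-slice
//   epilogue: fma C_sub_10, C_sub_11
// The A-B-B-A read order is what lets each k-slice's loads overlap the previous fma.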
template <
    index_t BlockSize,
    typename FloatA,
    typename FloatB,
    typename FloatC,
    typename AKMBlockDesc,
    typename BKNBlockDesc,
    index_t M1PerThreadM11,
    index_t N1PerThreadN11,
    index_t KPerThread,
    index_t M1N1ThreadClusterM100,
    index_t M1N1ThreadClusterN100,
    index_t M1N1ThreadClusterM101,
    index_t M1N1ThreadClusterN101,
    index_t AThreadCopyScalarPerVector_M11,
    index_t BThreadCopyScalarPerVector_N11,
    typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                       bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
    static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
    static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);

    static constexpr index_t M100 = M1N1ThreadClusterM100;
    static constexpr index_t N100 = M1N1ThreadClusterN100;

    static constexpr index_t M101 = M1N1ThreadClusterM101;
    static constexpr index_t N101 = M1N1ThreadClusterN101;

    static constexpr index_t M11 = M1PerThreadM11;
    static constexpr index_t N11 = N1PerThreadN11;

    static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
    static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;

    static constexpr index_t M0 = M / M1;
    static constexpr index_t N0 = N / N1;

    __host__ __device__ static constexpr auto
    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
    {
        const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
            AKMBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return a_k_m0_m1_block_desc;
    }

    __host__ __device__ static constexpr auto
    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
    {
        const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
            BKNBlockDesc{},
            make_tuple(make_pass_through_transform(Number<K>{}),
                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

        return b_k_n0_n1_block_desc;
    }

    __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M, N]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_unmerge_transform(make_tuple(
                               Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
    }

    __host__ __device__ static constexpr auto
    MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M0, M1, N0, N1]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(Number<M0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_pass_through_transform(Number<N0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
    }

    __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
    {
        return Sequence<M0, M11, N0, N11>{};
    }

    static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
    static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});

    public:
    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
              get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
    {
        static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == M101 * M100 * N101 * N100,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");

        static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2, "wrong");
    }

    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
    {
        // lower: [M0, M1, N0, N1]
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();

        // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
        // upper: [Tid, M0, M11, N0, N11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
                       make_pass_through_transform(M0),
                       make_pass_through_transform(M11),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(N11)),
            make_tuple(
                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

        return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
    }

    __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }

    __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; }

    template <typename CM0M1N0N1ThreadDesc,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
    __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
                      "wrong");

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
            a_k_m0_m1_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
            b_k_n0_n1_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                     FloatB,
                                                     FloatC,
                                                     decltype(a_k_m0_m1_thread_desc_),
                                                     decltype(b_k_n0_n1_thread_desc_),
                                                     CM0M1N0N1ThreadDesc,
                                                     Sequence<KPerThread>,
                                                     Sequence<1, M1PerThreadM11>,
                                                     Sequence<1, N1PerThreadN11>>{};

        // read A_sub_0
        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                           make_tuple(I0, I0, I0),
                           a_block_buf,
                           a_k_m0_m1_thread_desc_,
                           make_tuple(I0, I0, I0),
                           a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                           make_tuple(I0, I0, I0),
                           b_block_buf,
                           b_k_n0_n1_thread_desc_,
                           make_tuple(I0, I0, I0),
                           b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                           make_tuple(I0, I1, I0),
                           b_block_buf,
                           b_k_n0_n1_thread_desc_,
                           make_tuple(I0, I1, I0),
                           b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                           make_tuple(I0, I1, I0),
                           a_block_buf,
                           a_k_m0_m1_thread_desc_,
                           make_tuple(I0, I1, I0),
                           a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I0, I0),
                            b_thread_buf,
                            make_tuple(I0, I0, I0),
                            c_thread_buf,
                            make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I0, I0),
                            b_thread_buf,
                            make_tuple(I0, I1, I0),
                            c_thread_buf,
                            make_tuple(I0, I0, I1, I0));

        // loop over rest of k
        static_for<KPerThread, K, KPerThread>{}([&](auto k) {
            // read A_sub_0
            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                               make_tuple(k, I0, I0),
                               a_block_buf,
                               a_k_m0_m1_thread_desc_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I1, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                               make_tuple(k, I0, I0),
                               b_block_buf,
                               b_k_n0_n1_thread_desc_,
                               make_tuple(I0, I0, I0),
                               b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I1, I0),
                                b_thread_buf,
                                make_tuple(I0, I1, I0),
                                c_thread_buf,
                                make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                               make_tuple(k, I1, I0),
                               b_block_buf,
                               b_k_n0_n1_thread_desc_,
                               make_tuple(I0, I1, I0),
                               b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                               make_tuple(k, I1, I0),
                               a_block_buf,
                               a_k_m0_m1_thread_desc_,
                               make_tuple(I0, I1, I0),
                               a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I1, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I1, I0),
                            b_thread_buf,
                            make_tuple(I0, I0, I0),
                            c_thread_buf,
                            make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(a_thread_buf,
                            make_tuple(I0, I1, I0),
                            b_thread_buf,
                            make_tuple(I0, I1, I0),
                            c_thread_buf,
                            make_tuple(I1, I0, I1, I0));
    }

    private:
    // A[K, M0, M1]
    static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));

    // B[K, N0, N1]
    static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                         FloatA,
                                                         decltype(a_k_m0_m1_block_desc_),
                                                         decltype(a_k_m0_m1_thread_desc_),
                                                         Sequence<KPerThread, 1, M1PerThreadM11>,
                                                         Sequence<0, 1, 2>,
                                                         2,
                                                         AThreadCopyScalarPerVector_M11,
                                                         1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
                                                         FloatB,
                                                         decltype(b_k_n0_n1_block_desc_),
                                                         decltype(b_k_n0_n1_thread_desc_),
                                                         Sequence<KPerThread, 1, N1PerThreadN11>,
                                                         Sequence<0, 1, 2>,
                                                         2,
                                                         BThreadCopyScalarPerVector_N11,
                                                         1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif
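For orientation, a minimal instantiation sketch of the struct above; every concrete number is an illustrative assumption (a 128x128 block tile on 256 threads with K = 8), not a value taken from this commit:

// Hypothetical tile configuration; the numbers are assumptions for illustration only.
using AKMDesc = decltype(ck::make_naive_tensor_descriptor_packed(
    ck::make_tuple(ck::Number<8>{}, ck::Number<128>{})));  // A[K = 8, M = 128] in LDS
using BKNDesc = decltype(ck::make_naive_tensor_descriptor_packed(
    ck::make_tuple(ck::Number<8>{}, ck::Number<128>{})));  // B[K = 8, N = 128] in LDS

// M1 = M100 * M101 * M11 = 2 * 8 * 4 = 64, so M0 = 128 / 64 = 2 (same on the N side),
// satisfying the M0 == N0 == 2 restriction; BlockSize = 2 * 8 * 2 * 8 = 256.
using BlockGemm = ck::BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<
    256,     // BlockSize
    float,   // FloatA
    float,   // FloatB
    float,   // FloatC
    AKMDesc,
    BKNDesc,
    4,       // M1PerThreadM11
    4,       // N1PerThreadN11
    2,       // KPerThread
    2,       // M1N1ThreadClusterM100
    2,       // M1N1ThreadClusterN100
    8,       // M1N1ThreadClusterM101
    8,       // M1N1ThreadClusterN101
    4,       // AThreadCopyScalarPerVector_M11
    4>;      // BThreadCopyScalarPerVector_N11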
@@ -0,0 +1,410 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"

namespace ck {

// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visible to the whole block; C is distributed among the threads
// Assume:
//   1. A:
//      1. ABlockDesc_BK0_BM_BK1 is known at compile-time
//      2. ABlockBuffer is DynamicBuffer
//   2. B:
//      1. BBlockDesc_BK0_BN_BK1 is known at compile-time
//      2. BBlockBuffer is DynamicBuffer
//   3. C:
//      1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
//      2. CThreadBuffer is StaticBuffer
// Also assume:
//   BM10BN10ThreadClusterBM10Xs::Size() == BM10BN10ThreadClusterBN10Xs::Size() == 2
//   BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
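// For orientation, an assumed concrete decomposition (these numbers are illustrative,
// not from the commit): with BM10BN10ThreadClusterBM10Xs = Sequence<2, 8>,
// BM1PerThreadBM11 = 4 and BM = 128,
//   BM1 = BM100 * BM101 * BM11 = 2 * 8 * 4 = 64, so BM0 = BM / BM1 = 2 as required;
// the same split on the N side gives BlockSize = BM100 * BM101 * BN100 * BN101 = 256.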
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc_BK0_BM_BK1,
          typename BBlockDesc_BK0_BN_BK1,
          index_t BM1PerThreadBM11,
          index_t BN1PerThreadBN11,
          index_t BK0PerThread,
          typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
                                                //          BM10BN10ThreadClusterBM101, ...>
          typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
                                                //          BM10BN10ThreadClusterBN101, ...>
          index_t AThreadCopyScalarPerVector_BM11,
          index_t BThreadCopyScalarPerVector_BN11,
          typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                                 BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
    static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
    static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
    static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);

    static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
    static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];

    static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
    static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];

    static constexpr index_t BM11 = BM1PerThreadBM11;
    static constexpr index_t BN11 = BN1PerThreadBN11;

    static constexpr index_t BM1 = BM100 * BM101 * BM11;
    static constexpr index_t BN1 = BN100 * BN101 * BN11;

    static constexpr index_t BM0 = BM / BM1;
    static constexpr index_t BN0 = BN / BN1;

    __host__ __device__ static constexpr auto
    MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
    {
        const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
            a_block_desc_bk0_bm_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
                       make_pass_through_transform(Number<BK1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_block_bk0_bm0_bm1_bk1;
    }

    __host__ __device__ static constexpr auto
    MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
    {
        const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
            b_block_desc_bk0_bn_bk1,
            make_tuple(make_pass_through_transform(Number<BK0>{}),
                       make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
                       make_pass_through_transform(Number<BK1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_block_desc_bk0_bn0_bn1_bk1;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
    {
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // lower: [BM, BN]
        constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
                           make_unmerge_transform(make_tuple(
                               Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));

        return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
    {
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // lower: [BM0, BM1, BN0, BN1]
        constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(Number<BM0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
                           make_pass_through_transform(Number<BN0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));

        return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
    }

    __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
    {
        return Sequence<BM0, BM11, BN0, BN11>{};
    }

    static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
        MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});

    static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
        MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});

    public:
    __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
        : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
              get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
    {
        static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                          BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");

        static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
                          BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO: remove this restriction
        static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
                          BM10BN10ThreadClusterBN10Xs::Size() == 2,
                      "wrong!");

        // TODO: remove this restriction
        static_assert(BM0 == 2 && BN0 == 2, "wrong");
    }

    __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
    {
        // lower: [BM0, BM1, BN0, BN1]
        // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        constexpr auto adaptor0 =
            MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();

        // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // upper: [Tid, BM0, BM11, BN0, BN11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
                       make_pass_through_transform(BM0),
                       make_pass_through_transform(BM11),
                       make_pass_through_transform(BN0),
                       make_pass_through_transform(BN11)),
            make_tuple(
                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

        return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
    }

    template <typename CThreadDesc_BM0_BM11_BN0_BN11,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
    __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: remove this restriction
        static_assert(BM0 == 2 && BN0 == 2 &&
                          CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
                          CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
                      "wrong");

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
            a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
            b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

        constexpr auto threadwise_contraction =
            ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
                FloatA,
                FloatB,
                FloatC,
                decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
                decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
                CThreadDesc_BM0_BM11_BN0_BN11,
                Sequence<BK0PerThread, BK1>,
                Sequence<1, BM1PerThreadBM11>,
                Sequence<1, BN1PerThreadBN11>>{};

        // read A_sub_0
        a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           a_block_buf,
                           a_thread_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           b_block_buf,
                           b_thread_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I0, I0, I0),
                           b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           b_block_buf,
                           b_thread_desc_bk0_bn0_bn1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           a_block_buf,
                           a_thread_desc_bk0_bm0_bm1_bk1_,
                           make_tuple(I0, I1, I0, I0),
                           a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I0, I0, I1, I0));

        // loop over rest of bk0
        static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
            // read A_sub_0
            a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(bk0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(bk0, I0, I0, I0),
                               b_block_buf,
                               b_thread_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(I0, I0, I0, I0),
                               b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(bk0, I1, I0, I0),
                               b_block_buf,
                               b_thread_desc_bk0_bn0_bn1_bk1_,
                               make_tuple(I0, I1, I0, I0),
                               b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(bk0, I1, I0, I0),
                               a_block_buf,
                               a_thread_desc_bk0_bm0_bm1_bk1_,
                               make_tuple(I0, I1, I0, I0),
                               a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_contraction.Run(a_thread_buf,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf,
                                       make_tuple(I0, I1, I0, I0),
                                       c_thread_buf,
                                       make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I0, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_contraction.Run(a_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   b_thread_buf,
                                   make_tuple(I0, I1, I0, I0),
                                   c_thread_buf,
                                   make_tuple(I1, I0, I1, I0));
    }

    private:
    // A[BK0, BM0, BM1, BK1]
    static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));

    // B[BK0, BN0, BN1, BK1]
    static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
        make_naive_tensor_descriptor_packed(make_tuple(
            Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatA,
        FloatA,
        decltype(a_block_desc_bk0_bm0_bm1_bk1_),
        decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
        Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                             // DimAccessOrder
        Sequence<1, 1, BM1PerThreadBM11, BK1>,            // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                            // SrcVectorTensorContiguousDimOrder

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
        FloatB,
        FloatB,
        decltype(b_block_desc_bk0_bn0_bn1_bk1_),
        decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
        Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                             // DimAccessOrder
        Sequence<1, 1, BN1PerThreadBN11, BK1>,            // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                            // SrcVectorTensorContiguousDimOrder

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif
@@ -0,0 +1,175 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP

#include "common_header.hpp"
#include "threadwise_gemm_dlops_v3.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc_E1_K1_E2,
          typename BBlockDesc_E1_N_Ho_Wo_E2,
          typename CThreadDesc_K_N_Ho_Wo,
          index_t EPerThreadLoop,
          index_t KPerThreadLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto E1        = ABlockDesc_E1_K1_E2{}.GetLength(I0);
    static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
    static constexpr auto E2        = ABlockDesc_E1_K1_E2{}.GetLength(I2);

    static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
    static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);

    static constexpr auto KPerThread  = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
    static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
    static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);

    static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));

    static constexpr auto b_thread_mtx_ =
        make_naive_tensor_descriptor_packed(make_tuple(Number<EPerThreadLoop>{},
                                                       Number<1>{},
                                                       Number<HoPerThread>{},
                                                       Number<WoPerThread>{},
                                                       Number<E2>{}));

    static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
        Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
        : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
    {
        static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
                          BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
                          CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(
            ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
                ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
            "wrong! E dimension not consistent\n");

        static_assert(E1 % EPerThreadLoop == 0, "");
        static_assert(KPerThread % KPerThreadLoop == 0, "");

        static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
                          WoPerBlock % WoPerThread == 0,
                      "wrong! Cannot evenly divide work among threads\n");

        constexpr auto KThreadCluster = KPerBlock / KPerThread;
        constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
        constexpr auto WThreadCluster = WoPerBlock / WoPerThread;

        static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
                      "wrong! wrong blocksize\n");
    }

    __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
    {
        return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
    }

    __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
    {
        constexpr auto K0 = KPerBlock / KPerThread;
        constexpr auto N0 = I1;
        constexpr auto H0 = HoPerBlock / HoPerThread;
        constexpr auto W0 = WoPerBlock / WoPerThread;

        constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto c_k_n_h_w_thread_cluster_idx =
            c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(thread_id));

        return c_k_n_h_w_thread_cluster_idx;
    }

    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BThreadBuffer& b_thread_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(
            is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
                is_same<remove_cvref_t<typename BThreadBuffer::type>,
                        remove_cvref_t<FloatB>>::value &&
                is_same<remove_cvref_t<typename CThreadBuffer::type>,
                        remove_cvref_t<FloatC>>::value,
            "wrong! inconsistent type");

        constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};

        // thread A buffer for GEMM
        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
            a_thread_buf;

        constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
                                                                         FloatB,
                                                                         FloatC,
                                                                         decltype(a_thread_mtx_),
                                                                         decltype(b_thread_mtx_),
                                                                         decltype(c_thread_mtx_)>{};

        static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
            static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
                a_thread_copy_.Run(a_block_mtx,
                                   make_tuple(e_begin, k_begin, I0),
                                   a_block_buf,
                                   a_thread_mtx_,
                                   make_tuple(I0, I0, I0),
                                   a_thread_buf);

                threadwise_gemm.Run(a_thread_buf,
                                    make_tuple(I0, I0, I0),
                                    b_thread_buf,
                                    make_tuple(e_begin, I0, I0, I0, I0),
                                    c_thread_buf,
                                    make_tuple(k_begin, I0, I0, I0));
            });
        });
    }

    template <typename ABlockSliceMoveStepIdx>
    __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
    {
        a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
    }

    private:
    using AThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<FloatA,
                                         FloatA,
                                         ABlockDesc_E1_K1_E2,
                                         decltype(a_thread_mtx_),
                                         Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
                                         Sequence<0, 1, 2>,
                                         2,
                                         E2,
                                         E2>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
};

} // namespace ck
#endif
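A sketch of the intended call pattern for this v3 GEMM; the driver below is hypothetical (its name, arguments, and the window step are assumptions), but it uses only Run() and MoveABlockSliceWindow() as declared above:

template <typename BlockGemm, typename ABlockBuf, typename BThreadBuf, typename CThreadBuf>
__device__ void run_gemm_over_e(BlockGemm& block_gemm, // hypothetical driver
                                const ABlockBuf& a_block_buf,
                                const BThreadBuf& b_thread_buf,
                                CThreadBuf& c_thread_buf,
                                ck::index_t num_e_steps)
{
    for(ck::index_t step = 0; step < num_e_steps; ++step)
    {
        // accumulate C for the current E window (a real kernel would refresh
        // a_block_buf / b_thread_buf between steps)
        block_gemm.Run(a_block_buf, b_thread_buf, c_thread_buf);

        // slide the A read window along the first (E1) dimension for the next step
        block_gemm.MoveABlockSliceWindow(ck::make_multi_index(BlockGemm::E1, 0, 0));
    }
}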
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp (new file, 340 lines)
@@ -0,0 +1,340 @@
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t WaveSize = 64;

    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t KPerBlock =
        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    StaticBufferTupleOfVector<AddressSpaceEnum_t::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];

        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], Number<KPack>{} * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_n = wave_idx[I1];

        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], Number<KPack>{} * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }

    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                      "wrong!");
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
    static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            // read A
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            static_for<0, NRepeat, 1>{}([&](auto n0) {
                // read B
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);

                static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
                    vector_type<FloatAB, KPack> a_thread_vec;
                    vector_type<FloatAB, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }

    private:
    // A[M0, M1, M2, KPerBlock]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));

    // B[N0, N1, N2, KPerBlock]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerBlock>{}));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerBlock>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerBlock>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};

} // namespace ck
#endif
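A worked sizing check for the wave decomposition above, with assumed numbers (not from the commit): for MPerBlock = NPerBlock = 128, MPerXDL = NPerXDL = 32 and MRepeat = NRepeat = 2, MWaves = 128 / (2 * 32) = 2 and likewise NWaves = 2, so MWaves * NWaves * WaveSize = 2 * 2 * 64 = 256, which is exactly the BlockSize the constructor's static_assert demands.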
@@ -0,0 +1,172 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V4R1_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v3r1.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun,
|
||||
index_t NumThreadScratch = 1>
|
||||
struct BlockwiseTensorSliceTransfer_v4r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const DstElementwiseOperation& dst_element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               src_element_op,
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               dst_element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
        }
    }

    template <typename DstBuffer, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
        }
    }

    template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        Number<ThreadScratchId> thread_scratch_id)
    {
        RunRead(src_desc, src_buf, thread_scratch_id);
        RunWrite(dst_desc, dst_buf, thread_scratch_id);
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
                                           SrcElementwiseOperation,
                                           DstElementwiseOperation,
                                           DstInMemOp,
                                           SrcData,
                                           DstData,
                                           SrcDesc,
                                           DstDesc,
                                           SrcDimAccessOrder,
                                           DstDimAccessOrder,
                                           SrcVectorDim,
                                           DstVectorDim,
                                           SrcScalarPerVector,
                                           DstScalarPerVector,
                                           SrcScalarStrideInVector,
                                           DstScalarStrideInVector,
                                           ThreadTransferSrcResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun,
                                           NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
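The ThreadScratchId parameter threaded through RunRead/RunWrite above lets a kernel keep NumThreadScratch per-thread buffers in flight, reading the next tile while the previous one is still waiting to be written. The following is a minimal host-side sketch of that double-buffered loop structure, using std::vector in place of CK's buffer types; every name in it is illustrative, not library API.

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    // Pretend "global memory" holds 4 tiles of 8 elements each.
    constexpr int tile_size = 8, num_tiles = 4;
    std::vector<int> global_src(tile_size * num_tiles);
    std::iota(global_src.begin(), global_src.end(), 0);
    std::vector<int> global_dst(global_src.size(), 0);

    // Two scratch slots, as with NumThreadScratch = 2.
    std::vector<std::vector<int>> scratch(2, std::vector<int>(tile_size));

    auto run_read = [&](int tile, int slot) { // stands in for RunRead(..., Number<slot>{})
        std::copy_n(global_src.begin() + tile * tile_size, tile_size, scratch[slot].begin());
    };
    auto run_write = [&](int tile, int slot) { // stands in for RunWrite(..., Number<slot>{})
        std::copy_n(scratch[slot].begin(), tile_size, global_dst.begin() + tile * tile_size);
    };

    run_read(0, 0); // prologue: fill the first slot
    for(int tile = 1; tile < num_tiles; ++tile)
    {
        run_read(tile, tile % 2);            // fetch the next tile ...
        run_write(tile - 1, (tile - 1) % 2); // ... while draining the previous one
    }
    run_write(num_tiles - 1, (num_tiles - 1) % 2); // epilogue

    std::printf("copied %s\n", global_src == global_dst ? "correctly" : "INCORRECTLY");
}

On the GPU the point of the two slots is overlap: the read for tile i+1 can be issued before the write of tile i, hiding global-memory latency behind the LDS traffic.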
@@ -0,0 +1,156 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v5r1.hpp"

namespace ck {

// This version does the following things to avoid the scratch memory issue:
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v5r1 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v5r1::Run() does not construct a new tensor coordinate
template <index_t BlockSize,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          typename SrcVectorTensorLengths,
          typename DstVectorTensorLengths,
          typename SrcVectorTensorContiguousDimOrder,
          typename DstVectorTensorContiguousDimOrder,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v5r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
                                                           const Index& src_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin)
        : threadwise_transfer_(
              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename SrcStepHacks>
    __device__ void
    RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
        }
    }

    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    // SrcMoveSliceWindowStepHack controls the index calculation used when moving the slice window
    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& step,
                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(
                src_desc, step, src_move_slice_window_step_hack);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v5r1<ThreadSliceLengths,
                                           DstInMemOp,
                                           SrcData,
                                           DstData,
                                           SrcDesc,
                                           DstDesc,
                                           SrcDimAccessOrder,
                                           DstDimAccessOrder,
                                           SrcVectorTensorLengths,
                                           DstVectorTensorLengths,
                                           SrcVectorTensorContiguousDimOrder,
                                           DstVectorTensorContiguousDimOrder,
                                           ThreadTransferSrcResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
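In every one of these blockwise transfers, the constructor maps a thread's 1D block-local ID to a multi-dimensional cluster index and scales it by the per-thread slice lengths to get that thread's data origin. Below is a small host-side sketch of that mapping for a 2D tile, assuming a row-major (identity) ThreadClusterArrangeOrder; the concrete sizes are made up for illustration and are not taken from the headers above.

#include <array>
#include <cstdio>

int main()
{
    constexpr std::array<int, 2> block_slice_lengths{8, 128};   // data tile owned by the block
    constexpr std::array<int, 2> thread_cluster_lengths{4, 16}; // threads per dimension
    constexpr std::array<int, 2> thread_slice_lengths{
        block_slice_lengths[0] / thread_cluster_lengths[0],  // 2
        block_slice_lengths[1] / thread_cluster_lengths[1]}; // 8

    const int num_threads = thread_cluster_lengths[0] * thread_cluster_lengths[1];

    for(int tid = 0; tid < num_threads; ++tid)
    {
        // Row-major CalculateBottomIndex for a 2D cluster descriptor.
        const int cluster_idx0 = tid / thread_cluster_lengths[1];
        const int cluster_idx1 = tid % thread_cluster_lengths[1];

        // thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths
        const int origin0 = cluster_idx0 * thread_slice_lengths[0];
        const int origin1 = cluster_idx1 * thread_slice_lengths[1];

        std::printf("thread %3d copies [%d..%d) x [%d..%d)\n",
                    tid,
                    origin0, origin0 + thread_slice_lengths[0],
                    origin1, origin1 + thread_slice_lengths[1]);
    }
}

The 64 per-thread tiles printed here exactly cover the 8x128 block slice, which is what the static_assert on BlockSliceLengths == ThreadSliceLengths * ThreadClusterLengths enforces at compile time.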
@@ -0,0 +1,133 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r1.hpp"

namespace ck {

// This version does the following things to avoid the scratch memory issue:
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v6r1 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v6r1::Run() does not construct a new tensor coordinate
template <index_t BlockSize,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
                                                           const Index& src_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r1<SrcData,
                                           DstData,
                                           SrcDesc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrcResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
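Unlike v4r1/v5r1, v6r1 exposes a single fused Run() that applies one ElementwiseOperation while the data moves, ScalarPerVector elements at a time. A host-side sketch of that fused read-transform-write follows; the Scale functor is an invented stand-in for the operation, not library API.

#include <cstdio>
#include <vector>

// Stand-in for an ElementwiseOperation functor passed into the v6r1 transfer
// (illustrative only).
struct Scale
{
    float alpha;
    float operator()(float x) const { return alpha * x; }
};

int main()
{
    constexpr int n = 16, scalar_per_vector = 4; // n must divide evenly into vectors
    std::vector<float> src(n), dst(n, 0.f);
    for(int i = 0; i < n; ++i)
        src[i] = static_cast<float>(i);

    Scale op{2.f};

    // One fused Run(): read a vector, apply the elementwise op, write a vector.
    for(int v = 0; v < n; v += scalar_per_vector)
        for(int i = 0; i < scalar_per_vector; ++i)
            dst[v + i] = op(src[v + i]);

    std::printf("dst[5] = %.1f (expected 10.0)\n", dst[5]);
}

Fusing the op into the copy avoids a second pass over the data, which is the point of threading ElementwiseOperation through the transfer rather than running it as a separate kernel.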
@@ -0,0 +1,157 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r2.hpp"

namespace ck {

// This version does the following things to avoid the scratch memory issue:
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. It does not keep a reference to the tensor descriptors
// 3. Run() does not construct a new tensor coordinate
template <index_t BlockSize,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r2
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
                                                           const Index& src0_block_slice_origin,
                                                           const Src1Desc& src1_desc,
                                                           const Index& src1_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(src0_desc,
                               make_zero_multi_index<nDim>(),
                               src1_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrc0SliceOrigin(
                src0_desc, src0_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc1SliceOrigin(
                src1_desc, src1_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
                                           Src1Data,
                                           DstData,
                                           Src0Desc,
                                           Src1Desc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrc0ResetCoordinateAfterRun,
                                           ThreadTransferSrc1ResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
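v6r2 generalizes the fused transfer to two source windows feeding one destination through a binary ElementwiseOperation, which allows, for example, a residual add to be folded into the output copy. A host-side sketch under that assumption; the Add functor is illustrative, not library API.

#include <cstdio>
#include <vector>

// Stand-in for a binary ElementwiseOperation, e.g. adding a residual tensor
// while writing a result out (illustrative only).
struct Add
{
    float operator()(float x0, float x1) const { return x0 + x1; }
};

int main()
{
    constexpr int n = 8;
    std::vector<float> src0(n, 1.5f), src1(n, 0.5f), dst(n, 0.f);

    Add op;

    // One fused Run() over both source windows and the destination window.
    for(int i = 0; i < n; ++i)
        dst[i] = op(src0[i], src1[i]);

    std::printf("dst[0] = %.1f (expected 2.0)\n", dst[0]);
}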
@@ -0,0 +1,182 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v6r3.hpp"

namespace ck {

// This version does the following things to avoid the scratch memory issue:
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v6r3 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v6r3::Run() does not construct a new tensor coordinate
template <index_t BlockSize,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename Src2Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename Src2Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferSrc2ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v6r3
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
                                                           const Index& src0_block_slice_origin,
                                                           const Src1Desc& src1_desc,
                                                           const Index& src1_block_slice_origin,
                                                           const Src2Desc& src2_desc,
                                                           const Index& src2_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin,
                                                           const ElementwiseOperation& element_op)
        : threadwise_transfer_(src0_desc,
                               make_zero_multi_index<nDim>(),
                               src1_desc,
                               make_zero_multi_index<nDim>(),
                               src2_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<Src0Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<Src1Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<Src2Desc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrc0SliceOrigin(
                src0_desc, src0_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc1SliceOrigin(
                src1_desc, src1_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc2SliceOrigin(
                src2_desc, src2_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const Src2Desc& src2_desc,
                        const Src2Buffer& src2_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(
                src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
                                           Src1Data,
                                           Src2Data,
                                           DstData,
                                           Src0Desc,
                                           Src1Desc,
                                           Src2Desc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrc0ResetCoordinateAfterRun,
                                           ThreadTransferSrc1ResetCoordinateAfterRun,
                                           ThreadTransferSrc2ResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
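All of these transfers are designed to be constructed once and then stepped across a tensor with MoveSrcSliceWindow/MoveDstSliceWindow inside a kernel's main loop, so the slice origins are set up a single time and only incremented afterwards. A host-side sketch of that loop structure over a 1D window; sizes and names are illustrative.

#include <cstdio>
#include <vector>

int main()
{
    // A 1D "tensor" of 32 elements, processed through an 8-element slice window.
    constexpr int n = 32, window_len = 8;
    std::vector<int> data(n, 1);

    int window_origin = 0; // set once, like the slice origin in the constructors above
    int total         = 0;

    for(int iter = 0; iter < n / window_len; ++iter)
    {
        for(int i = 0; i < window_len; ++i) // stands in for Run() on the current window
            total += data[window_origin + i];

        window_origin += window_len; // stands in for MoveSrcSliceWindow(desc, step)
    }

    std::printf("total = %d (expected %d)\n", total, n);
}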
@@ -0,0 +1,185 @@
/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP

#include "data_type.hpp"

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"

namespace ck {

template <typename Buffer1dDescType,
          typename AccDataType,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          bool ReorderThreadClusters,
          typename OpReduce,
          bool PropagateNan>
struct PartitionedBlockwiseReductionOn1dBuffer
{
    static constexpr auto buffer_1d_desc = Buffer1dDescType{};

    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "The product of the cluster lengths should be the same as BlockSize!");
    static_assert(KThreadClusterSize > 1,
                  "Parallel reduction needs to work on at least two elements!");

    static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
                  "The buffer size should be the same as BlockSize!");

    using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;

    template <typename BufferType>
    __device__ static void Reduce(BufferType& block_buffer,
                                  AccDataType& accuData,
                                  index_t thread_m_cluster_id,
                                  index_t thread_k_cluster_id)
    {
        constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                // Respect the thread-cluster order so that contiguous locations are
                // accessed by contiguous thread IDs.
                index_t offset1 =
                    ReorderThreadClusters
                        ? buffer_1d_desc.CalculateOffset(make_tuple(
                              thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
                        : buffer_1d_desc.CalculateOffset(make_tuple(
                              thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
                index_t offset2 = ReorderThreadClusters
                                      ? buffer_1d_desc.CalculateOffset(make_tuple(
                                            (thread_k_cluster_id + indOffset) * MThreadClusterSize +
                                            thread_m_cluster_id))
                                      : buffer_1d_desc.CalculateOffset(
                                            make_tuple(thread_m_cluster_id * KThreadClusterSize +
                                                       (thread_k_cluster_id + indOffset)));

                AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
                AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
                Accumulation::Calculate(opData1, opData2);
                block_buffer(offset1) = type_convert<AccDataType>(opData1);
            }

            __syncthreads();
        });

        index_t offset = ReorderThreadClusters
                             ? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
                             : buffer_1d_desc.CalculateOffset(
                                   make_tuple(thread_m_cluster_id * KThreadClusterSize));

        accuData = type_convert<AccDataType>(block_buffer[offset]);
    }
};
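// Worked example (illustrative, not part of the original header): with
// KThreadClusterSize = 8 (so three halving rounds) and one M-row holding the
// partial values {3, 1, 4, 1, 5, 9, 2, 6} under a sum-reduce, the loop above
// proceeds as
//   indOffset = 4: {8, 10, 6, 7, -, -, -, -}
//   indOffset = 2: {14, 17, -, -, -, -, -, -}
//   indOffset = 1: {31, -, -, -, -, -, -, -}
// ("-" marks stale entries). After the loop, thread_k_cluster_id == 0 holds the
// row total 31, which is what the final CalculateOffset read returns for this row.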

template <typename Buffer1dDescType,
          typename AccDataType,
          typename IndexDataType,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          bool ReorderThreadClusters,
          typename OpReduce,
          bool PropagateNan>
struct PartitionedBlockwiseReductionWithIndexOn1dBuffer
{
    static constexpr auto buffer_1d_desc = Buffer1dDescType{};

    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "The product of the cluster lengths should be the same as BlockSize!");
    static_assert(KThreadClusterSize > 1,
                  "Parallel reduction needs to work on at least two elements!");

    static_assert(buffer_1d_desc.GetElementSize() == BlockSize,
                  "The buffer size should be the same as BlockSize!");

    using Accumulation =
        detail::AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>;

    // This interface accumulates both data values and indices
    template <typename BufferType, typename IdxBufferType>
    __device__ static void Reduce(BufferType& block_val_buffer,
                                  IdxBufferType& block_idx_buffer,
                                  AccDataType& accuData,
                                  IndexDataType& accuIndex,
                                  index_t thread_m_cluster_id,
                                  index_t thread_k_cluster_id)
    {
        constexpr auto cluster_len_shift = get_shift<KThreadClusterSize>();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << I();

            if(thread_k_cluster_id % (indOffset * 2) == 0)
            {
                // Respect the thread-cluster order so that contiguous locations are
                // accessed by contiguous thread IDs.
                index_t offset1 =
                    ReorderThreadClusters
                        ? buffer_1d_desc.CalculateOffset(make_tuple(
                              thread_k_cluster_id * MThreadClusterSize + thread_m_cluster_id))
                        : buffer_1d_desc.CalculateOffset(make_tuple(
                              thread_m_cluster_id * KThreadClusterSize + thread_k_cluster_id));
                index_t offset2 = ReorderThreadClusters
                                      ? buffer_1d_desc.CalculateOffset(make_tuple(
                                            (thread_k_cluster_id + indOffset) * MThreadClusterSize +
                                            thread_m_cluster_id))
                                      : buffer_1d_desc.CalculateOffset(
                                            make_tuple(thread_m_cluster_id * KThreadClusterSize +
                                                       (thread_k_cluster_id + indOffset)));

                AccDataType opData1      = type_convert<AccDataType>(block_val_buffer[offset1]);
                AccDataType opData2      = type_convert<AccDataType>(block_val_buffer[offset2]);
                IndexDataType currIndex1 = block_idx_buffer[offset1];
                IndexDataType currIndex2 = block_idx_buffer[offset2];

                Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
                block_val_buffer(offset1) = type_convert<AccDataType>(opData1);
                block_idx_buffer(offset1) = currIndex1;
            }

            __syncthreads();
        });

        index_t offset = ReorderThreadClusters
                             ? buffer_1d_desc.CalculateOffset(make_tuple(thread_m_cluster_id))
                             : buffer_1d_desc.CalculateOffset(
                                   make_tuple(thread_m_cluster_id * KThreadClusterSize));

        accuData  = type_convert<AccDataType>(block_val_buffer[offset]);
        accuIndex = block_idx_buffer[offset];
    }
};

} // namespace ck

#endif
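For the indexed variant above, the stride-doubling schedule (indOffset = 1 << I, with only threads at even multiples of 2 * indOffset active) carries the winning element's index alongside its value, which is how argmax/argmin reductions are built. A host-side sketch that simulates one M-row with a max-reduce standing in for OpReduce; all concrete values are illustrative.

#include <cstdio>
#include <vector>

int main()
{
    constexpr int k_cluster = 8, shift = 3; // 8 lanes reduced in 3 rounds
    std::vector<float> val{3, 1, 4, 1, 5, 9, 2, 6};
    std::vector<int> idx{0, 1, 2, 3, 4, 5, 6, 7};

    for(int round = 0; round < shift; ++round)
    {
        const int ind_offset = 1 << round; // 1, 2, 4
        for(int k = 0; k < k_cluster; ++k)
        {
            if(k % (ind_offset * 2) != 0)
                continue; // only "threads" at even multiples participate this round
            // Stand-in for Accumulation::Calculate(op1, op2, idx1, idx2) with a max op:
            if(val[k + ind_offset] > val[k])
            {
                val[k] = val[k + ind_offset];
                idx[k] = idx[k + ind_offset];
            }
        }
        // __syncthreads() would go here on the GPU
    }

    std::printf("max = %.0f at index %d (expected 9 at 5)\n", val[0], idx[0]);
}

Note the difference from the value-only variant: there the active set is the compact prefix thread_k_cluster_id < indOffset, while here the survivors stay at their original strided positions so each value's provenance index never has to move.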