mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Use Tuple and vector_type instead of Array for holding tensor data (#30)
* replacing array with tuple and vector for tensor data
This commit is contained in:
@@ -2,16 +2,27 @@
|
||||
#define CK_BLOCKWISE_GEMM_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_gemm_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
|
||||
// C[M, N] += transpose(A[K, M]) * B[K, N]
|
||||
// A and B are visable to the whole block, C is distributed among each thread
|
||||
// If following number are power of 2, index calculation shall be greatly reduced:
|
||||
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
|
||||
// MLevel1ThreadCluster, NLevel1ThreadCluster
|
||||
// Assume:
|
||||
// 1. A:
|
||||
// 1. BlockMatrixA is known at compile-time
|
||||
// 2. ABlockBuffer is DynamicBuffer
|
||||
// 2. B:
|
||||
// 1. BlockMatrixA is known at compile-time
|
||||
// 2. BBlockBuffer is DynamicBuffer
|
||||
// 3. C:
|
||||
// 1. ThreadMatrixC is known at compile-time
|
||||
// 2. CThreadBuffer is StaticBuffer
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename BlockMatrixA,
|
||||
typename BlockMatrixB,
|
||||
typename ThreadMatrixC,
|
||||
@@ -23,8 +34,12 @@ template <index_t BlockSize,
|
||||
index_t MLevel1ThreadCluster,
|
||||
index_t NLevel1ThreadCluster,
|
||||
index_t ThreadGemmADataPerRead_M,
|
||||
index_t ThreadGemmBDataPerRead_N>
|
||||
struct BlockwiseGemm_km_kn_m0m1n0n1_v1
|
||||
index_t ThreadGemmBDataPerRead_N,
|
||||
typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() &&
|
||||
BlockMatrixB::IsKnownAtCompileTime() &&
|
||||
ThreadMatrixC::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
|
||||
{
|
||||
struct MatrixIndex
|
||||
{
|
||||
@@ -32,10 +47,49 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
|
||||
index_t col;
|
||||
};
|
||||
|
||||
index_t mMyThreadOffsetA;
|
||||
index_t mMyThreadOffsetB;
|
||||
private:
|
||||
static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
|
||||
|
||||
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
|
||||
static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
|
||||
|
||||
using AThreadCopy =
|
||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
BlockMatrixA,
|
||||
decltype(a_thread_mtx_desc_),
|
||||
Sequence<KPerThreadLoop, MPerThreadSubC>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ThreadGemmADataPerRead_M,
|
||||
AddressSpace::Generic,
|
||||
AddressSpace::Vgpr,
|
||||
1>;
|
||||
|
||||
using BThreadCopy =
|
||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
BlockMatrixB,
|
||||
decltype(b_thread_mtx_desc_),
|
||||
Sequence<KPerThreadLoop, NPerThreadSubC>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ThreadGemmBDataPerRead_N,
|
||||
AddressSpace::Generic,
|
||||
AddressSpace::Vgpr,
|
||||
1>;
|
||||
|
||||
MatrixIndex c_thread_begin_mtx_idx_;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
|
||||
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
|
||||
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
|
||||
b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
|
||||
{
|
||||
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
|
||||
BlockMatrixB::IsKnownAtCompileTime() &&
|
||||
@@ -51,23 +105,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
|
||||
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
|
||||
|
||||
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent\n");
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
|
||||
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
|
||||
|
||||
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
|
||||
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
|
||||
"wrong! Cannot evenly divide work among\n");
|
||||
"wrong! Cannot evenly divide work among");
|
||||
|
||||
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
|
||||
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
|
||||
"wrong! ThreadMatrixC lengths is wrong");
|
||||
|
||||
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
|
||||
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetThreadMatrixCLengths()
|
||||
@@ -104,103 +153,30 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void
|
||||
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBlockBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx_desc = ThreadMatrixC{};
|
||||
|
||||
constexpr auto K = a_block_mtx.GetLength(I0);
|
||||
|
||||
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
|
||||
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
|
||||
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
|
||||
decltype(a_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
MPerThreadSubC,
|
||||
ThreadGemmADataPerRead_M>{};
|
||||
|
||||
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
|
||||
decltype(b_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
NPerThreadSubC,
|
||||
ThreadGemmBDataPerRead_N>{};
|
||||
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
|
||||
decltype(b_thread_mtx),
|
||||
decltype(c_thread_mtx)>{};
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
#pragma unroll
|
||||
// read A
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
a_thread_copy.Run(p_a_block +
|
||||
a_block_mtx.CalculateOffset(
|
||||
make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
|
||||
mMyThreadOffsetA,
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(
|
||||
make_tuple(0, m_repeat * MPerThreadSubC)));
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
// read B
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
b_thread_copy.Run(p_b_block +
|
||||
b_block_mtx.CalculateOffset(
|
||||
make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
|
||||
mMyThreadOffsetB,
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(
|
||||
make_tuple(0, n_repeat * NPerThreadSubC)));
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void
|
||||
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr auto K = a_block_mtx.GetLength(I0);
|
||||
|
||||
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
|
||||
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
|
||||
constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0);
|
||||
constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
@@ -211,15 +187,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert(MRepeat == 2 && NRepeat == 2,
|
||||
"wrong! inline asm cannot deal with this GEMM config yet");
|
||||
|
||||
// thread A, B
|
||||
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
|
||||
|
||||
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
|
||||
static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");
|
||||
|
||||
// thread A-sub, B-sub
|
||||
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
|
||||
@@ -234,113 +202,152 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
|
||||
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
|
||||
make_tuple(Number<NPerThread>{}, Number<1>{}));
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
|
||||
auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
|
||||
decltype(a_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
MPerThreadSubC,
|
||||
ThreadGemmADataPerRead_M>{};
|
||||
|
||||
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
|
||||
decltype(b_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
NPerThreadSubC,
|
||||
ThreadGemmBDataPerRead_N>{};
|
||||
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
|
||||
decltype(b_thread_sub_mtx),
|
||||
decltype(c_thread_sub_mtx)>{};
|
||||
|
||||
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
|
||||
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_thread_sub_mtx),
|
||||
decltype(b_thread_sub_mtx),
|
||||
decltype(c_thread_sub_mtx)>{};
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy.Run(p_a_block_off, p_a_thread);
|
||||
a_thread_copy_.Run(BlockMatrixA{},
|
||||
make_tuple(I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_mtx_desc_,
|
||||
make_tuple(I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy.Run(p_b_block_off, p_b_thread);
|
||||
b_thread_copy_.Run(BlockMatrixB{},
|
||||
make_tuple(I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_mtx_desc_,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy.Run(p_b_block_off +
|
||||
b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
|
||||
b_thread_copy_.Run(BlockMatrixB{},
|
||||
make_tuple(I0, Number<NPerLevel1Cluster>{}),
|
||||
b_block_buf,
|
||||
b_thread_mtx_desc_,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy.Run(p_a_block_off +
|
||||
a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
|
||||
a_thread_copy_.Run(BlockMatrixA{},
|
||||
make_tuple(I0, Number<MPerLevel1Cluster>{}),
|
||||
a_block_buf,
|
||||
a_thread_mtx_desc_,
|
||||
make_tuple(I0, Number<MPerThreadSubC>{}),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(
|
||||
p_a_thread,
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
|
||||
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}));
|
||||
|
||||
#pragma unroll
|
||||
// loop over rest of k
|
||||
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
|
||||
{
|
||||
static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
|
||||
// read A_sub_0
|
||||
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
|
||||
p_a_thread);
|
||||
a_thread_copy_.Run(BlockMatrixA{},
|
||||
make_tuple(k, I0),
|
||||
a_block_buf,
|
||||
a_thread_mtx_desc_,
|
||||
make_tuple(I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
|
||||
p_b_thread,
|
||||
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, Number<MPerThreadSubC>{}),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(Number<MPerThreadSubC>{}, I0));
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
|
||||
p_b_thread);
|
||||
b_thread_copy_.Run(BlockMatrixB{},
|
||||
make_tuple(k, I0),
|
||||
b_block_buf,
|
||||
b_thread_mtx_desc_,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
|
||||
p_c_thread +
|
||||
c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, Number<MPerThreadSubC>{}),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}),
|
||||
c_thread_buf,
|
||||
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy.Run(
|
||||
p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
|
||||
b_thread_copy_.Run(BlockMatrixB{},
|
||||
make_tuple(k, Number<NPerLevel1Cluster>{}),
|
||||
b_block_buf,
|
||||
b_thread_mtx_desc_,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy.Run(
|
||||
p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
|
||||
a_thread_copy_.Run(BlockMatrixA{},
|
||||
make_tuple(k, Number<MPerLevel1Cluster>{}),
|
||||
a_block_buf,
|
||||
a_thread_mtx_desc_,
|
||||
make_tuple(I0, Number<MPerThreadSubC>{}),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(
|
||||
p_a_thread,
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
|
||||
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
|
||||
}
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}));
|
||||
});
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
|
||||
p_b_thread,
|
||||
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, Number<MPerThreadSubC>{}),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(Number<MPerThreadSubC>{}, I0));
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
|
||||
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, Number<MPerThreadSubC>{}),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, Number<NPerThreadSubC>{}),
|
||||
c_thread_buf,
|
||||
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
|
||||
}
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -354,17 +361,16 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
|
||||
|
||||
if constexpr(MRepeat == 2 && NRepeat == 2)
|
||||
{
|
||||
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
|
||||
Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
else
|
||||
{
|
||||
Run_naive(p_a_block, p_b_block, p_c_thread);
|
||||
Run_naive(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
#else
|
||||
Run_naive(p_a_block, p_b_block, p_c_thread);
|
||||
Run_naive(a_block_buf, b_block_buf, c_thread_buf);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -6,12 +6,10 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
|
||||
// A and B are visable to the whole block, C is distributed among each thread
|
||||
// If following number are power of 2, index calculation shall be greatly reduced:
|
||||
// KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
|
||||
// MLevel1ThreadCluster, NLevel1ThreadCluster
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename BlockMatrixA,
|
||||
typename BlockMatrixB,
|
||||
typename ThreadMatrixC,
|
||||
@@ -30,9 +28,34 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
index_t w;
|
||||
};
|
||||
|
||||
index_t mMyThreadOffsetA;
|
||||
// HACK: fix this @Jing Zhang
|
||||
static constexpr index_t KPerThreadSubC = 4;
|
||||
|
||||
static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
|
||||
|
||||
static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
||||
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||
|
||||
static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
||||
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||
|
||||
using AThreadCopy =
|
||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
BlockMatrixA,
|
||||
decltype(a_thread_mtx_),
|
||||
Sequence<EPerThreadLoop, KPerThreadSubC>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ThreadGemmADataPerRead_K,
|
||||
AddressSpace::Generic,
|
||||
AddressSpace::Vgpr,
|
||||
1>;
|
||||
|
||||
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
|
||||
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
|
||||
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
|
||||
{
|
||||
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
|
||||
BlockMatrixB::IsKnownAtCompileTime() &&
|
||||
@@ -61,11 +84,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
|
||||
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
|
||||
"wrong! wrong blocksize\n");
|
||||
|
||||
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA =
|
||||
BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread));
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetThreadMatrixCLengths()
|
||||
@@ -91,37 +109,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
|
||||
}
|
||||
|
||||
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
index_t NSliceRow,
|
||||
index_t NSliceCol,
|
||||
index_t DataPerAccess>
|
||||
struct ThreadwiseSliceCopy_a
|
||||
{
|
||||
template <typename Data>
|
||||
__device__ static void Run(const Data* p_src, Data* p_dst)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
|
||||
|
||||
static_for<0, NSliceRow, 1>{}([&](auto i) {
|
||||
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
|
||||
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
|
||||
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
|
||||
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void
|
||||
Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
|
||||
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BThreadBuffer& b_thread_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -132,8 +131,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
|
||||
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
|
||||
|
||||
constexpr auto KPerThreadSubC = 4;
|
||||
|
||||
// HACK: fix this @Jing Zhang
|
||||
constexpr auto HoPerThreadSubC = 2;
|
||||
constexpr auto WoPerThreadSubC = 2;
|
||||
|
||||
@@ -141,63 +139,53 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
static_assert(HPerThread % HoPerThreadSubC == 0, "");
|
||||
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
|
||||
// thread A buffer for GEMM
|
||||
StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
|
||||
|
||||
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
||||
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||
|
||||
constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
||||
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
|
||||
|
||||
constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
|
||||
decltype(a_thread_mtx),
|
||||
EPerThreadLoop,
|
||||
KPerThreadSubC,
|
||||
ThreadGemmADataPerRead_K>{};
|
||||
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
|
||||
decltype(b_thread_mtx),
|
||||
decltype(c_thread_mtx),
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_thread_mtx_),
|
||||
decltype(b_thread_mtx_),
|
||||
decltype(c_thread_mtx_),
|
||||
HoPerThreadSubC,
|
||||
WoPerThreadSubC>{};
|
||||
// loop over k
|
||||
#pragma unroll
|
||||
for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
|
||||
{
|
||||
a_thread_copy.Run(p_a_block +
|
||||
a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) +
|
||||
mMyThreadOffsetA,
|
||||
p_a_thread);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC)
|
||||
{
|
||||
threadwise_gemm.Run(p_a_thread,
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(
|
||||
e_begin, 0, h_begin, w_begin)),
|
||||
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(
|
||||
k_begin, 0, h_begin, w_begin)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
|
||||
static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
|
||||
|
||||
a_thread_copy_.Run(a_block_mtx,
|
||||
make_tuple(e_begin, k_begin),
|
||||
a_block_buf,
|
||||
a_thread_mtx_,
|
||||
make_tuple(I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
|
||||
static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(e_begin, I0, h_begin, w_begin),
|
||||
c_thread_buf,
|
||||
make_tuple(k_begin, I0, h_begin, w_begin));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
|
||||
template <typename ABlockSliceMoveStepIdx>
|
||||
__device__ void MoveASliceWindow(const BlockMatrixA&,
|
||||
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
|
||||
{
|
||||
Run_naive(p_a_block, p_b_thread, p_c_thread);
|
||||
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
|
||||
}
|
||||
|
||||
private:
|
||||
MatrixIndex c_thread_begin_mtx_idx_;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,9 +5,10 @@
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_v2.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "blockwise_gemm_v2.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -256,19 +257,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
|
||||
decltype(a_k_m_block_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
decltype(c_m0m1_n0n1_thread_desc),
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
MLevel0Cluster,
|
||||
NLevel0Cluster,
|
||||
MLevel1Cluster,
|
||||
NLevel1Cluster,
|
||||
MPerThread,
|
||||
NPerThread>{};
|
||||
BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k_m_block_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
decltype(c_m0m1_n0n1_thread_desc),
|
||||
MPerThread,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
MLevel0Cluster,
|
||||
NLevel0Cluster,
|
||||
MLevel1Cluster,
|
||||
NLevel1Cluster,
|
||||
MPerThread,
|
||||
NPerThread>{};
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
@@ -281,10 +285,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
|
||||
|
||||
// register allocation for output
|
||||
FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
|
||||
auto c_thread_buf =
|
||||
make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
|
||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m0m1_n0n1_thread_desc),
|
||||
Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
|
||||
.Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
|
||||
@@ -300,6 +307,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
|
||||
BGlobalMoveSliceWindowIteratorHacks{};
|
||||
|
||||
FloatAB* p_a_block_even = p_a_block_double;
|
||||
FloatAB* p_b_block_even = p_b_block_double;
|
||||
|
||||
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
|
||||
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
|
||||
auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
|
||||
auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
|
||||
@@ -311,12 +330,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
FloatAB* p_a_block_even = p_a_block_double;
|
||||
FloatAB* p_b_block_even = p_b_block_double;
|
||||
|
||||
FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
|
||||
FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
|
||||
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
@@ -340,7 +353,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
|
||||
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
|
||||
@@ -363,7 +376,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
|
||||
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
|
||||
@@ -390,7 +403,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
|
||||
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
|
||||
@@ -399,16 +412,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_a_block_double + a_block_space_size,
|
||||
p_b_block_double + b_block_space_size,
|
||||
p_c_thread);
|
||||
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
|
||||
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
@@ -461,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
|
||||
n_thread_data_on_global % N1))
|
||||
.Run(c_m0_m1_n0_n1_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
p_c_thread,
|
||||
c_thread_buf,
|
||||
c_m0_m1_n0_n1_global_desc,
|
||||
p_c_global,
|
||||
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
|
||||
|
||||
@@ -145,17 +145,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
||||
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
|
||||
decltype(a_e_k_block_desc),
|
||||
decltype(b_e_n_ho_wo_block_desc),
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K>{};
|
||||
auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_e_k_block_desc),
|
||||
decltype(b_e_n_ho_wo_block_desc),
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K>{};
|
||||
|
||||
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
@@ -223,11 +225,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
|
||||
// register allocation for output
|
||||
FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
|
||||
auto a_block_buf = make_dynamic_buffer(p_a_block);
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
|
||||
// register allocation for output
|
||||
StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;
|
||||
|
||||
// initialize output thread tensor
|
||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
|
||||
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
||||
|
||||
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
||||
|
||||
@@ -242,12 +249,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
||||
BGlobalMoveSliceWindowIteratorHacks{};
|
||||
|
||||
constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
|
||||
FloatAB p_b_thread[b_thread_space_size * 2];
|
||||
// double regsiter buffer for b
|
||||
StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
|
||||
b_thread_odd_buf;
|
||||
|
||||
FloatAB* p_b_thread_double = p_b_thread;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
// LDS double buffer: preload data
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
|
||||
|
||||
@@ -255,7 +261,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
p_b_global,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
p_b_thread_double,
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
|
||||
@@ -263,13 +269,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
|
||||
__syncthreads();
|
||||
|
||||
index_t b_block_data_begin = 0;
|
||||
|
||||
#if 1
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
FloatAB* p_b_thread_even = p_b_thread_double;
|
||||
FloatAB* p_b_thread_odd = p_b_thread_double + b_thread_space_size;
|
||||
index_t e_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
@@ -283,16 +285,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
p_b_global,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
p_b_thread_odd,
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
|
||||
p_b_thread_even,
|
||||
p_c_thread);
|
||||
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
b_block_data_begin += EPerBlock;
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
@@ -301,18 +301,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
p_b_global,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
p_b_thread_even,
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
|
||||
p_b_thread_odd,
|
||||
p_c_thread);
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
|
||||
b_block_data_begin += EPerBlock;
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
} while(b_block_data_begin < E - 2 * EPerBlock);
|
||||
e_block_data_begin += 2 * EPerBlock;
|
||||
|
||||
} while(e_block_data_begin < E - 2 * EPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
@@ -325,34 +324,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
p_b_global,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
p_b_thread_double + b_thread_space_size,
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_iterator_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
|
||||
p_b_thread_double,
|
||||
p_c_thread);
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
b_block_data_begin += EPerBlock;
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
|
||||
p_b_thread_double + b_thread_space_size,
|
||||
p_c_thread);
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
|
||||
p_b_thread_double,
|
||||
p_c_thread);
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// output: register to global memory
|
||||
{
|
||||
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
||||
@@ -380,12 +368,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
|
||||
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
|
||||
.Run(c_k_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
p_c_thread,
|
||||
c_thread_buf,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
c_k_n_ho_wo_global_tensor_iterator_hacks);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// pass tensor descriptor by reference
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
|
||||
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Assume:
|
||||
// 1. Desc is known at compile-time
|
||||
// 2. Buffer is StaticBuffer
|
||||
// 3. OriginIdx is known at compile-time
|
||||
// 4. use #-iterator
|
||||
template <typename Data,
|
||||
typename Desc,
|
||||
typename SliceLengths,
|
||||
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseDynamicTensorSliceSet_v1
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
template <typename OriginIdx, typename Buffer>
|
||||
__device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
|
||||
{
|
||||
static_assert(Desc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
|
||||
"wrong! OriginIdx need to be known at compile-time");
|
||||
|
||||
// Desc is known at compile-time
|
||||
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};
|
||||
|
||||
// OriginIdx is known at compile-time
|
||||
constexpr auto origin_idx = to_multi_index(OriginIdx{});
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto access_idx) {
|
||||
constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);
|
||||
|
||||
constexpr bool is_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
buf(Number<offset>{}) = initial_value;
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -7,6 +7,15 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
|
||||
// and sometimes useless instructions:
|
||||
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
|
||||
// instead
|
||||
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
|
||||
// tensor coordinate instead
|
||||
// 3. Don't use a pointer to VGPR buffer, use vector instead
|
||||
|
||||
namespace detail {
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor
|
||||
template <index_t VectorDim, index_t ScalarPerVector>
|
||||
@@ -26,12 +35,17 @@ struct lambda_scalar_step_in_vector
|
||||
return (i == VectorDim) ? 1 : 0;
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
// this version is less likely to have scratch memory issue, due to:
|
||||
// 1. It does not keep reference to tensor descriptor
|
||||
// 2. It does not construct new tensor coordinate for this->Run()
|
||||
// Assume src_slice_origin_idx is 0
|
||||
// TODO: support non-zero src_slice_oring_idx
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is StaticBuffer
|
||||
// 3. SrcSliceOrginIdx is known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is not known at compile-time
|
||||
// 2. DstBuffer is DynamicBuffer
|
||||
// 3. DstSliceOrginIdx is not known at compile time
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
@@ -69,10 +83,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx, typename DstIteratorHacks>
|
||||
template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcData* p_src,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstData* p_dst,
|
||||
const DstIteratorHacks& dst_iterator_hacks)
|
||||
@@ -84,9 +98,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcSliceOriginIdx>>>::value,
|
||||
"wrong! SrcSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value,
|
||||
"wrong! SrcBuffer data type is wrong");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
|
||||
constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -94,10 +114,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -178,12 +198,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset =
|
||||
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
|
||||
i * dst_scalar_step_in_vector);
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector);
|
||||
|
||||
dst_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>{}(p_src[Number<src_offset>{}]);
|
||||
type_convert<DstData>{}(src_buf[Number<src_offset>{}]);
|
||||
});
|
||||
|
||||
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
@@ -284,7 +303,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -359,10 +378,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
DstCoord dst_slice_origin_coord_;
|
||||
}; // namespace ck
|
||||
|
||||
// this version is less likely to have scratch memory issue, due to:
|
||||
// 1. It does not keep reference to tensor descriptor
|
||||
// 2. It does not construct new tensor coordinate for this->Run()
|
||||
// Assume dst_slice_origin_idx is 0
|
||||
// Assume:
|
||||
// 1. src_desc is not known at compile-time
|
||||
// 2. dst_desc is known at compile-time
|
||||
// 3. src_slice_origin_idx is not known at compile-time
|
||||
// 4. dst_slice_origin_idx is known at compile-time and it's 0
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
@@ -399,12 +419,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename DstSliceOriginIdx, typename SrcIteratorHacks>
|
||||
template <typename DstBuffer, typename DstSliceOriginIdx, typename SrcIteratorHacks>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcData* p_src,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstData* p_dst,
|
||||
DstBuffer& dst_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
@@ -414,6 +434,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstSliceOriginIdx>>>::value,
|
||||
"wrong! DstSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
// DstDesc and dst_slice_origin_idx are known at compile-time
|
||||
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
|
||||
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
|
||||
@@ -424,10 +448,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_scalar_step_in_vector =
|
||||
generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -541,7 +565,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
p_dst[Number<dst_offset>{}] = src_vector.template AsType<SrcData>()[i];
|
||||
dst_buf(Number<dst_offset>{}) = src_vector.template AsType<SrcData>()[i];
|
||||
});
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
@@ -590,7 +614,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst)
|
||||
template <typename DstBuffer, typename DstSliceOriginIdx>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcData* p_src,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
|
||||
|
||||
@@ -600,7 +629,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
Run(src_desc, p_src, p_dst, src_iterator_hacks);
|
||||
Run(src_desc, p_src, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
@@ -610,7 +639,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -685,12 +714,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
SrcCoord src_slice_origin_coord_;
|
||||
}; // namespace ck
|
||||
|
||||
// this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
|
||||
// and sometimes useless instructions
|
||||
// 1. It does not keep reference to tensor descriptor
|
||||
// 2. It does not construct new tensor coordinate for this->Run()
|
||||
// 3. It does not use pointer for VGPR thread buffer
|
||||
// 4. It calculate offset for thread buffer directly, instead of moving the coordinate
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 3. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
InMemoryDataOperation DstInMemOp,
|
||||
typename SrcData,
|
||||
@@ -737,6 +764,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
static_assert(DstAddressSpace == AddressSpace::Global or
|
||||
DstAddressSpace == AddressSpace::Lds,
|
||||
"wrong!");
|
||||
|
||||
// TODO: fix this
|
||||
static_assert(is_same<SrcData, DstData>::value,
|
||||
"wrong! current implementation assume SrcData and DstData are same type");
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
@@ -760,10 +791,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_scalar_step_in_vector =
|
||||
generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -838,11 +869,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
return src_data_idx;
|
||||
}();
|
||||
|
||||
// copy data
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
|
||||
// copy data from src_buf to src_tmp_vector
|
||||
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
||||
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
|
||||
using src_vector_t = typename decltype(src_tmp_vector)::type;
|
||||
|
||||
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_desc, src_slice_origin_coord_);
|
||||
@@ -850,14 +880,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
if constexpr(SrcAddressSpace == AddressSpace::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
amd_buffer_load_v2<SrcData, SrcScalarPerVector>(
|
||||
p_src,
|
||||
src_slice_origin_coord_.GetOffset(),
|
||||
is_src_valid,
|
||||
src_desc.GetElementSpaceSize());
|
||||
#else
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
@@ -865,17 +895,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
}
|
||||
else
|
||||
{
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
is_src_valid ? *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_slice_origin_coord_.GetOffset()])
|
||||
: src_vector_t{0};
|
||||
}
|
||||
|
||||
// copy data from src_tmp_vector to buffer_
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
buffer_(Number<buffer_offset>{}) = src_vector.template AsType<SrcData>()[i];
|
||||
buffer_(Number<buffer_offset>{}) = src_tmp_vector.template AsType<SrcData>()[i];
|
||||
});
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
@@ -937,10 +968,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
// src scalar per access on each dim
|
||||
// TODO: don't use this
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -1026,20 +1057,21 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
DstInMemOp == InMemoryDataOperation::Set,
|
||||
"wrong! hardcoded for ds_write");
|
||||
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
|
||||
vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
|
||||
|
||||
// copy data from buffer_ to dst_tmp_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
|
||||
|
||||
dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
|
||||
dst_tmp_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
|
||||
});
|
||||
|
||||
using DstVectorType =
|
||||
typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
|
||||
using dst_vector_t = typename decltype(dst_tmp_vector)::type;
|
||||
|
||||
*reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
|
||||
dst_vector.template AsType<DstVectorType>()[Number<0>{}];
|
||||
// copy data from dst_tmp_vector to dst_buf
|
||||
*reinterpret_cast<dst_vector_t*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
|
||||
dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}];
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
@@ -1123,7 +1155,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
|
||||
|
||||
@@ -1185,7 +1217,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
|
||||
|
||||
@@ -1274,7 +1306,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
|
||||
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
@@ -1297,11 +1328,203 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
|
||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||
|
||||
StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
|
||||
StaticBuffer<SrcData, buffer_size_> buffer_;
|
||||
|
||||
SrcCoord src_slice_origin_coord_;
|
||||
DstCoord dst_slice_origin_coord_;
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is DynamicBuffer
|
||||
// 3. src_ref_idx is known at run-time
|
||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||
// 5. use #-iterator
|
||||
// 2. dst:
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
// 3. DstOriginIdx is known at compile-time
|
||||
// 4. use direct address calculation
|
||||
// 3. vector access on src
|
||||
template <
    typename SrcData,
    typename DstData,
    typename SrcDesc,
    typename DstDesc,
    typename SliceLengths,
    typename DimAccessOrder,
    index_t SrcVectorDim,
    index_t SrcScalarPerVector,
    AddressSpace SrcAddressSpace,
    AddressSpace DstAddressSpace,
    index_t SrcScalarStrideInVector,
    typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                            bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v4
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));

    // Only the run-time source reference coordinate is stored; both descriptors
    // are stateless compile-time objects (enforced below and by the enable_if).
    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
        : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");
    }

    // Copy the slice from src_buf (dynamic buffer, run-time offsets) into
    // dst_buf (static buffer, compile-time offsets). The src access position is
    // computed relative to the stored reference coordinate; the relative
    // displacement and the dst origin must be compile-time constants.
    // Out-of-bound src reads (per the descriptor's validity check) yield 0.
    template <typename SrcRefToOriginDisplacement,
              typename DstOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcRefToOriginDisplacement&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
                              remove_cv_t<remove_reference_t<SrcData>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
                                  remove_cv_t<remove_reference_t<DstData>>>::value,
                      "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

        static_assert(
            is_known_at_compile_time<
                remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
            "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
            "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};

        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
        constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // scalar per access of each dim: SrcScalarPerVector along SrcVectorDim,
        // 1 along every other dimension
        constexpr auto src_scalar_per_access = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<SrcScalarPerVector>{};
                }
                else
                {
                    return Number<1>{};
                }
            },
            Number<nDim>{});

        // scalar step (if steping on SrcVectorDim) of each dim
        constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<1>{};
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<nDim>{});

        // number of vectorized accesses along each dimension
        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // fully unrolled loop over every vectorized access of the slice window
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
            // TODO: unable to compile
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                src_scalar_per_access;
#else
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif

            // src coordinate: compile-time displacement from the stored
            // run-time reference coordinate
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

            constexpr auto src_ref_to_data_disp_coord_iterator =
                make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

            move_dynamic_tensor_coordinate(
                src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);

            // copy data from src_buf into src_tmp_buffer
            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            // single vectorized load; invalid (out-of-bound) positions read as 0
            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                is_src_valid ? src_buf.template Get<src_vector_t>(src_data_coord.GetOffset())
                             : src_vector_t{0};

            // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
            // DstData)
            vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

            // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                dst_tmp_vector.template AsType<DstData>()(i) =
                    type_convert<DstData>{}(src_tmp_vector.template AsType<SrcData>()[i]);
            });

            // copy data from dst_tmp_vector into dst_buf; dst offsets are
            // pure compile-time constants (direct address calculation)
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
            });
        });
    }

    // Shift the stored src reference coordinate by the given step.
    // src_slice_move_step_idx should be known at compile-time for performance
    // (it is converted into a compile-time multi-index below).
    template <typename SrcSliceMoveStepIdx>
    __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
    {
        constexpr auto src_desc = SrcDesc{};

        const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
            src_desc, to_multi_index(src_slice_move_step_idx));

        move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
    }

    private:
    // run-time coordinate of the slice-window reference point in src
    SrcCoord src_ref_coord_;
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -6,100 +6,52 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename Float, typename Desc>
|
||||
__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
|
||||
{
|
||||
static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto desc = Desc{};
|
||||
|
||||
constexpr auto M = desc.GetLength(I0);
|
||||
constexpr auto N = desc.GetLength(I1);
|
||||
|
||||
static_for<0, M, 1>{}([&](auto i) {
|
||||
static_for<0, N, 1>{}([&](auto j) {
|
||||
constexpr auto offset = desc.CalculateOffset(make_tuple(i, j));
|
||||
|
||||
p_thread[offset] = Float(0);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
index_t NSliceRow,
|
||||
index_t NSliceCol,
|
||||
index_t DataPerAccess>
|
||||
struct ThreadwiseMatrixSliceCopy_v2
|
||||
{
|
||||
template <typename Data>
|
||||
__device__ static void Run(const Data* p_src, Data* p_dst)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
|
||||
|
||||
static_for<0, NSliceRow, 1>{}([&](auto i) {
|
||||
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
|
||||
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
|
||||
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
|
||||
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// C[M, N] += transpose(A[K, M]) * B[K, N]
|
||||
// Element of matrix can be vectorized data
|
||||
template <typename ADesc,
|
||||
// Assume:
|
||||
// 1. ADesc, BDesc, CDesc are known at compile-time
|
||||
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
|
||||
template <typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc,
|
||||
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemm_km_kn_mn_v1
|
||||
struct ThreadwiseGemm_km_kn_mn_v1r1
|
||||
{
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
|
||||
template <typename ABuffer,
|
||||
typename AOriginIdx,
|
||||
typename BBuffer,
|
||||
typename BOriginIdx,
|
||||
typename CBuffer,
|
||||
typename COriginIdx>
|
||||
__device__ static void Run(const ABuffer& a_buf,
|
||||
AOriginIdx,
|
||||
const BBuffer& b_buf,
|
||||
BOriginIdx,
|
||||
CBuffer& c_buf,
|
||||
COriginIdx)
|
||||
{
|
||||
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
constexpr auto M = CDesc{}.GetLength(I0);
|
||||
constexpr auto N = CDesc{}.GetLength(I1);
|
||||
constexpr auto K = ADesc{}.GetLength(I0);
|
||||
|
||||
static_for<0, K, 1>{}([&](auto k) {
|
||||
static_for<0, M, 1>{}([&](auto m) {
|
||||
static_for<0, N, 1>{}([&](auto n) {
|
||||
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
|
||||
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
|
||||
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));
|
||||
|
||||
p_c[c_offset] +=
|
||||
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
|
||||
{
|
||||
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -110,62 +62,82 @@ struct ThreadwiseGemm_km_kn_mn_v1
|
||||
constexpr auto N = CDesc{}.GetLength(I1);
|
||||
constexpr auto K = ADesc{}.GetLength(I0);
|
||||
|
||||
static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
|
||||
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, K, 1>{}([&](auto k) {
|
||||
static_for<0, M, 1>{}([&](auto m) {
|
||||
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
|
||||
constexpr index_t a_offset =
|
||||
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
|
||||
|
||||
#if 0
|
||||
if constexpr(N == 2)
|
||||
{
|
||||
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
|
||||
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
|
||||
constexpr index_t b_offset_0 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
|
||||
constexpr index_t b_offset_1 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
|
||||
|
||||
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
|
||||
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
|
||||
constexpr index_t c_offset_0 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
|
||||
constexpr index_t c_offset_1 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
|
||||
|
||||
amd_assembly_outer_product_1x2(p_a[a_offset],
|
||||
p_b[b_offset_0],
|
||||
p_b[b_offset_1],
|
||||
p_c[c_offset_0],
|
||||
p_c[c_offset_1]);
|
||||
amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset_0>{}],
|
||||
b_buf[Number<b_offset_1>{}],
|
||||
c_buf(Number<c_offset_0>{}),
|
||||
c_buf(Number<c_offset_1>{}));
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
|
||||
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
|
||||
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2));
|
||||
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3));
|
||||
constexpr index_t b_offset_0 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
|
||||
constexpr index_t b_offset_1 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
|
||||
constexpr index_t b_offset_2 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
|
||||
constexpr index_t b_offset_3 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
|
||||
|
||||
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
|
||||
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
|
||||
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2));
|
||||
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3));
|
||||
constexpr index_t c_offset_0 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
|
||||
constexpr index_t c_offset_1 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
|
||||
constexpr index_t c_offset_2 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
|
||||
constexpr index_t c_offset_3 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a[a_offset],
|
||||
p_b[b_offset_0],
|
||||
p_b[b_offset_1],
|
||||
p_b[b_offset_2],
|
||||
p_b[b_offset_3],
|
||||
p_c[c_offset_0],
|
||||
p_c[c_offset_1],
|
||||
p_c[c_offset_2],
|
||||
p_c[c_offset_3]);
|
||||
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset_0>{}],
|
||||
b_buf[Number<b_offset_1>{}],
|
||||
b_buf[Number<b_offset_2>{}],
|
||||
b_buf[Number<b_offset_3>{}],
|
||||
c_buf(Number<c_offset_0>{}),
|
||||
c_buf(Number<c_offset_1>{}),
|
||||
c_buf(Number<c_offset_2>{}),
|
||||
c_buf(Number<c_offset_3>{}));
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
static_for<0, N, 1>{}([&](auto n) {
|
||||
|
||||
constexpr index_t b_offset =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
|
||||
constexpr index_t c_offset =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
|
||||
|
||||
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset>{}],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
|
||||
{
|
||||
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
Run_amd_asm(p_a, p_b, p_c);
|
||||
#else
|
||||
Run_source(p_a, p_b, p_c);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -6,35 +6,15 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename Float, typename Desc>
|
||||
__device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread)
|
||||
{
|
||||
static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto desc = Desc{};
|
||||
|
||||
constexpr auto K = desc.GetLength(I0);
|
||||
constexpr auto H = desc.GetLength(I2);
|
||||
constexpr auto W = desc.GetLength(I3);
|
||||
|
||||
static_for<0, K, 1>{}([&](auto i) {
|
||||
static_for<0, H, 1>{}([&](auto j) {
|
||||
static_for<0, W, 1>{}([&](auto k) {
|
||||
constexpr auto offset = desc.CalculateOffset(make_tuple(i, 0, j, k));
|
||||
|
||||
p_thread[offset] = Float(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C[M, N] += transpose(A[K, M]) * B[K, N]
|
||||
// Element of matrix can be vectorized data
|
||||
template <typename ADesc,
|
||||
// Assume:
|
||||
// 1. ADesc, BDesc, CDesc are known at compile-time
|
||||
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
|
||||
template <typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc,
|
||||
index_t H,
|
||||
@@ -44,13 +24,37 @@ template <typename ADesc,
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemm_km_kn_mn_v3
|
||||
{
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
|
||||
template <typename ABuffer,
|
||||
typename AOriginIdx,
|
||||
typename BBuffer,
|
||||
typename BOriginIdx,
|
||||
typename CBuffer,
|
||||
typename COriginIdx>
|
||||
__device__ static void Run(const ABuffer& a_buf,
|
||||
AOriginIdx,
|
||||
const BBuffer& b_buf,
|
||||
BOriginIdx,
|
||||
CBuffer& c_buf,
|
||||
COriginIdx)
|
||||
{
|
||||
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -59,79 +63,100 @@ struct ThreadwiseGemm_km_kn_mn_v3
|
||||
constexpr auto E = ADesc{}.GetLength(I0);
|
||||
constexpr auto K = ADesc{}.GetLength(I1);
|
||||
|
||||
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, E, 1>{}([&](auto e) {
|
||||
static_for<0, K, 1>{}([&](auto k) {
|
||||
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
|
||||
constexpr index_t a_offset =
|
||||
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));
|
||||
|
||||
if constexpr(H == 2 && W == 2)
|
||||
{
|
||||
constexpr index_t b_offset_0 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
|
||||
constexpr index_t b_offset_1 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
|
||||
constexpr index_t b_offset_2 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
|
||||
constexpr index_t b_offset_3 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
|
||||
|
||||
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
|
||||
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1));
|
||||
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
|
||||
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1));
|
||||
constexpr index_t c_offset_0 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
|
||||
constexpr index_t c_offset_1 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
|
||||
constexpr index_t c_offset_2 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
|
||||
constexpr index_t c_offset_3 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));
|
||||
|
||||
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
|
||||
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1));
|
||||
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
|
||||
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1));
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a[a_offset],
|
||||
p_b[b_offset_0],
|
||||
p_b[b_offset_1],
|
||||
p_b[b_offset_2],
|
||||
p_b[b_offset_3],
|
||||
p_c[c_offset_0],
|
||||
p_c[c_offset_1],
|
||||
p_c[c_offset_2],
|
||||
p_c[c_offset_3]);
|
||||
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset_0>{}],
|
||||
b_buf[Number<b_offset_1>{}],
|
||||
b_buf[Number<b_offset_2>{}],
|
||||
b_buf[Number<b_offset_3>{}],
|
||||
c_buf(Number<c_offset_0>{}),
|
||||
c_buf(Number<c_offset_1>{}),
|
||||
c_buf(Number<c_offset_2>{}),
|
||||
c_buf(Number<c_offset_3>{}));
|
||||
}
|
||||
else if constexpr(H == 4 && W == 1)
|
||||
{
|
||||
constexpr index_t b_offset_0 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
|
||||
constexpr index_t b_offset_1 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
|
||||
constexpr index_t b_offset_2 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
|
||||
constexpr index_t b_offset_3 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
|
||||
|
||||
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
|
||||
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
|
||||
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0));
|
||||
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0));
|
||||
constexpr index_t c_offset_0 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
|
||||
constexpr index_t c_offset_1 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
|
||||
constexpr index_t c_offset_2 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
|
||||
constexpr index_t c_offset_3 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));
|
||||
|
||||
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
|
||||
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
|
||||
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0));
|
||||
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0));
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a[a_offset],
|
||||
p_b[b_offset_0],
|
||||
p_b[b_offset_1],
|
||||
p_b[b_offset_2],
|
||||
p_b[b_offset_3],
|
||||
p_c[c_offset_0],
|
||||
p_c[c_offset_1],
|
||||
p_c[c_offset_2],
|
||||
p_c[c_offset_3]);
|
||||
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset_0>{}],
|
||||
b_buf[Number<b_offset_1>{}],
|
||||
b_buf[Number<b_offset_2>{}],
|
||||
b_buf[Number<b_offset_3>{}],
|
||||
c_buf(Number<c_offset_0>{}),
|
||||
c_buf(Number<c_offset_1>{}),
|
||||
c_buf(Number<c_offset_2>{}),
|
||||
c_buf(Number<c_offset_3>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, H, 1>{}([&](auto h) {
|
||||
static_for<0, W, 1>{}([&](auto w) {
|
||||
constexpr auto b_offset =
|
||||
BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
|
||||
constexpr auto c_offset =
|
||||
CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));
|
||||
|
||||
p_c[c_offset] += inner_product_with_conversion<FloatC>{}(p_a[a_offset],
|
||||
p_b[b_offset]);
|
||||
constexpr index_t b_offset =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w));
|
||||
|
||||
constexpr index_t c_offset =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
|
||||
|
||||
#if 0
|
||||
c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
|
||||
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
|
||||
#else
|
||||
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset>{}],
|
||||
c_buf(Number<c_offset>{}));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
|
||||
{
|
||||
Run_source(p_a, p_b, p_c);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,6 +5,75 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// c += inner_product(a, b)
// Single fused multiply-accumulate emitted as raw GCN assembly.
// The "0"(c) input constraint ties c to output operand %0, making the
// instruction a read-modify-write accumulate.
__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_V_FMAC_F32
    // v_fmac_f32: single-rounding fused multiply-add (newer ISAs)
    asm volatile("\n \
            v_fmac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    // v_mac_f32: multiply-accumulate fallback for older ISAs
    asm volatile("\n \
            v_mac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#endif
}
|
||||
|
||||
// c += dot4(a, b): 4-way int8 dot product accumulated into an int32.
// NOTE(review): v_dot4_i32_i8 / __builtin_amdgcn_sdot4 require a DL-capable
// GPU (e.g. gfx906+); there is no compile-time guard here — confirm targets.
__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if 1
    // inline asm path; operands are bit-cast to i32 as the instruction expects
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0\n \
            "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    // compiler-intrinsic equivalent (disabled)
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}
|
||||
|
||||
// c += inner_product(a, b) for 8-wide int8 vectors, computed as two 4-wide
// dot products accumulated into the same int32.
__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    // materialize each operand once instead of re-wrapping per sub-product
    const vector_type<int8_t, 8> a_vec{a};
    const vector_type<int8_t, 8> b_vec{b};

    amd_assembly_inner_product(a_vec.AsType<int8x4_t>()[I0], b_vec.AsType<int8x4_t>()[I0], c);
    amd_assembly_inner_product(a_vec.AsType<int8x4_t>()[I1], b_vec.AsType<int8x4_t>()[I1], c);
}
|
||||
|
||||
// c += inner_product(a, b) for 16-wide int8 vectors, computed as four 4-wide
// dot products accumulated into the same int32.
__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    // materialize each operand once instead of re-wrapping per sub-product
    const vector_type<int8_t, 16> a_vec{a};
    const vector_type<int8_t, 16> b_vec{b};

    amd_assembly_inner_product(a_vec.AsType<int8x4_t>()[I0], b_vec.AsType<int8x4_t>()[I0], c);
    amd_assembly_inner_product(a_vec.AsType<int8x4_t>()[I1], b_vec.AsType<int8x4_t>()[I1], c);
    amd_assembly_inner_product(a_vec.AsType<int8x4_t>()[I2], b_vec.AsType<int8x4_t>()[I2], c);
    amd_assembly_inner_product(a_vec.AsType<int8x4_t>()[I3], b_vec.AsType<int8x4_t>()[I3], c);
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
|
||||
|
||||
72
composable_kernel/include/utility/buffer.hpp
Normal file
72
composable_kernel/include/utility/buffer.hpp
Normal file
@@ -0,0 +1,72 @@
|
||||
#ifndef CK_BUFFER_HPP
|
||||
#define CK_BUFFER_HPP
|
||||
|
||||
#include "statically_indexed_array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Register-resident buffer of N elements of T, addressable only with
// compile-time indices (inherits indexing from StaticallyIndexedArray).
// IsStaticBuffer/IsDynamicBuffer let generic code dispatch on buffer kind.
template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
|
||||
|
||||
// Factory: create a StaticBuffer<T, N> with the size carried as a Number<N>
// so it can be deduced from a compile-time value at the call site.
template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<T, N>{};
}
|
||||
|
||||
// Non-owning view over run-time addressed memory (e.g. global or LDS).
// Provides scalar access via operator[]/operator() and vectorized access via
// Get/Set, which reinterpret the underlying storage as a wider vector type X
// sharing the same scalar element type as T.
template <typename T>
struct DynamicBuffer
{
    using type = T;

    // non-owning pointer; lifetime is managed by the caller
    T* p_data_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    // Vectorized read of X starting at element i. Enabled only when X and T
    // have the same scalar type.
    // NOTE(review): the reinterpret_cast assumes &p_data_[i] is suitably
    // aligned for X — callers must guarantee this.
    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i) const
    {
        return *reinterpret_cast<const X*>(&p_data_[i]);
    }

    // Vectorized write of X starting at element i; same scalar-type and
    // alignment requirements as Get.
    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, const X& x)
    {
        *reinterpret_cast<X*>(&p_data_[i]) = x;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
|
||||
|
||||
// Factory: wrap a raw pointer in a (non-owning) DynamicBuffer<T>.
template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
    return DynamicBuffer<T>{p};
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "statically_indexed_array.hpp"
|
||||
#include "container_element_picker.hpp"
|
||||
#include "float_type.hpp"
|
||||
#include "buffer.hpp"
|
||||
#include "functional.hpp"
|
||||
#include "functional2.hpp"
|
||||
#include "functional3.hpp"
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
#define CK_DEVICE_BACKEND_AMD 1
|
||||
|
||||
// GPU ID
|
||||
#if 1
|
||||
#if 0
|
||||
#define CK_AMD_GPU_GFX906 1
|
||||
#elif 0
|
||||
#define CK_AMD_GPU_GFX908 1
|
||||
#elif 0
|
||||
#elif 1
|
||||
#define CK_AMD_GPU_GFX1030 1
|
||||
#endif
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
#endif
|
||||
|
||||
// launch bounds
|
||||
#define CK_USE_LAUNCH_BOUNDS 0
|
||||
#define CK_USE_LAUNCH_BOUNDS 1
|
||||
|
||||
#ifdef CK_USE_LAUNCH_BOUNDS
|
||||
#define CK_MAX_THREAD_PER_BLOCK 256
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
#ifndef CK_FLOAT_TYPE_AMD_HPP
|
||||
#define CK_FLOAT_TYPE_AMD_HPP
|
||||
|
||||
#include "statically_indexed_array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
using half_t = _Float16;
|
||||
@@ -43,6 +45,15 @@ struct vector_type_maker<vector_type<T, N1>, N0>
|
||||
using type = vector_type<T, N0 * N1>;
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
|
||||
|
||||
template <typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_vector_type(Number<N>)
|
||||
{
|
||||
return typename vector_type_maker<T, N>::type{};
|
||||
}
|
||||
|
||||
// scalar_type
|
||||
template <typename TV>
|
||||
struct scalar_type;
|
||||
@@ -403,32 +414,249 @@ struct vector_type<T, 16>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 32>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
typedef T d4_t __attribute__((ext_vector_type(4)));
|
||||
typedef T d8_t __attribute__((ext_vector_type(8)));
|
||||
typedef T d16_t __attribute__((ext_vector_type(16)));
|
||||
typedef T d32_t __attribute__((ext_vector_type(32)));
|
||||
|
||||
using type = d32_t;
|
||||
|
||||
union
|
||||
{
|
||||
d32_t d32_;
|
||||
StaticallyIndexedArray<d1_t, 32> d1x32_;
|
||||
StaticallyIndexedArray<d2_t, 16> d2x16_;
|
||||
StaticallyIndexedArray<d4_t, 8> d4x8_;
|
||||
StaticallyIndexedArray<d8_t, 4> d8x4_;
|
||||
StaticallyIndexedArray<d16_t, 2> d16x2_;
|
||||
StaticallyIndexedArray<d32_t, 1> d32x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
|
||||
"wrong!");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x1_;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
|
||||
"wrong!");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x1_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 64>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
typedef T d4_t __attribute__((ext_vector_type(4)));
|
||||
typedef T d8_t __attribute__((ext_vector_type(8)));
|
||||
typedef T d16_t __attribute__((ext_vector_type(16)));
|
||||
typedef T d32_t __attribute__((ext_vector_type(32)));
|
||||
typedef T d64_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
using type = d64_t;
|
||||
|
||||
union
|
||||
{
|
||||
d64_t d64_;
|
||||
StaticallyIndexedArray<d1_t, 64> d1x64_;
|
||||
StaticallyIndexedArray<d2_t, 32> d2x32_;
|
||||
StaticallyIndexedArray<d4_t, 16> d4x16_;
|
||||
StaticallyIndexedArray<d8_t, 8> d8x8_;
|
||||
StaticallyIndexedArray<d16_t, 4> d16x4_;
|
||||
StaticallyIndexedArray<d32_t, 2> d32x2_;
|
||||
StaticallyIndexedArray<d64_t, 1> d64x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value,
|
||||
"wrong!");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x64_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d64_t>::value)
|
||||
{
|
||||
return data_.d64x1_;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value,
|
||||
"wrong!");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x64_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d64_t>::value)
|
||||
{
|
||||
return data_.d64x1_;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// fp32
|
||||
using float2_t = typename vector_type<float, 2>::type;
|
||||
using float4_t = typename vector_type<float, 4>::type;
|
||||
using float8_t = typename vector_type<float, 8>::type;
|
||||
using float2_t = typename vector_type<float, 2>::type;
|
||||
using float4_t = typename vector_type<float, 4>::type;
|
||||
using float8_t = typename vector_type<float, 8>::type;
|
||||
using float16_t = typename vector_type<float, 16>::type;
|
||||
using float32_t = typename vector_type<float, 32>::type;
|
||||
using float64_t = typename vector_type<float, 64>::type;
|
||||
|
||||
// fp16
|
||||
using half2_t = typename vector_type<half_t, 2>::type;
|
||||
using half4_t = typename vector_type<half_t, 4>::type;
|
||||
using half8_t = typename vector_type<half_t, 8>::type;
|
||||
using half16_t = typename vector_type<half_t, 16>::type;
|
||||
using half32_t = typename vector_type<half_t, 32>::type;
|
||||
using half64_t = typename vector_type<half_t, 64>::type;
|
||||
|
||||
// bfp16
|
||||
using ushort2_t = typename vector_type<ushort, 2>::type;
|
||||
using ushort4_t = typename vector_type<ushort, 4>::type;
|
||||
using ushort8_t = typename vector_type<ushort, 8>::type;
|
||||
using ushort2_t = typename vector_type<ushort, 2>::type;
|
||||
using ushort4_t = typename vector_type<ushort, 4>::type;
|
||||
using ushort8_t = typename vector_type<ushort, 8>::type;
|
||||
using ushort16_t = typename vector_type<ushort, 16>::type;
|
||||
using ushort32_t = typename vector_type<ushort, 32>::type;
|
||||
using ushort64_t = typename vector_type<ushort, 64>::type;
|
||||
|
||||
// i32
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
using int32x16_t = typename vector_type<int32_t, 16>::type;
|
||||
using int32x32_t = typename vector_type<int32_t, 32>::type;
|
||||
using int32x64_t = typename vector_type<int32_t, 64>::type;
|
||||
|
||||
// i8
|
||||
using int8x2_t = typename vector_type<int8_t, 2>::type;
|
||||
using int8x4_t = typename vector_type<int8_t, 4>::type;
|
||||
using int8x8_t = typename vector_type<int8_t, 8>::type;
|
||||
using int8x16_t = typename vector_type<int8_t, 16>::type;
|
||||
using int8x32_t = typename vector_type<int8_t, 32>::type;
|
||||
using int8x64_t = typename vector_type<int8_t, 64>::type;
|
||||
|
||||
// data type conversion
|
||||
template <typename T>
|
||||
|
||||
@@ -5,11 +5,26 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
|
||||
{
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
// F returns index_t
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
|
||||
{
|
||||
return typename sequence_gen<N, F>::type{};
|
||||
}
|
||||
|
||||
// F returns Number<>
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// run-time variables
|
||||
constexpr auto in_n_hi_wi_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
|
||||
@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 1080;
|
||||
@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
|
||||
print_array("ConvStrides", to_multi_index(ConvStrides{}));
|
||||
print_array("ConvDilations", to_multi_index(ConvDilations{}));
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
using in_data_t = float;
|
||||
constexpr index_t in_vector_size = 1;
|
||||
using acc_data_t = float;
|
||||
@@ -724,23 +724,22 @@ int main(int argc, char* argv[])
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
out_data_t>
|
||||
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
out_data_t>(
|
||||
in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
in_vector_size,
|
||||
|
||||
Reference in New Issue
Block a user