Mirror of https://github.com/ROCm/composable_kernel.git
reorganized files
@@ -0,0 +1,12 @@
#ifndef CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER
#define CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER

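// Illustrative usage (not part of this header): a host-side launch of this
// wrapper would look roughly like
//
//     run_gridwise_convolution_kernel<GridwiseConv, float>
//         <<<dim3(GridSize), dim3(BlockSize)>>>(p_in, p_wei, p_out);
//
// where GridwiseConv is one of the GridwiseConvolution* structs added in this
// commit and GridSize/BlockSize match its template arguments.
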
template <class GridwiseConvolution, class T>
__global__ void run_gridwise_convolution_kernel(const T* const __restrict__ p_in_global,
                                                const T* const __restrict__ p_wei_global,
                                                T* const __restrict__ p_out_global)
{
    GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global);
}

#endif
@@ -0,0 +1,254 @@
#ifndef CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW
#define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_direct_convolution.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t CPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t InBlockCopyDataPerRead,
          index_t WeiBlockCopyDataPerRead>
struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_nchw_global_desc = InGlobalDesc{};
        constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
        constexpr auto out_nkhw_global_desc = OutGlobalDesc{};

        constexpr index_t N = in_nchw_global_desc.GetLength(I0);
        constexpr index_t K = wei_kcyx_global_desc.GetLength(I0);
        constexpr index_t C = wei_kcyx_global_desc.GetLength(I1);
        constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2);
        constexpr index_t X = wei_kcyx_global_desc.GetLength(I3);

        constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor_packed(
            Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;
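        // For stride-1, unpadded convolution, a HoPerBlock x WoPerBlock output tile
        // reads a (HoPerBlock + Y - 1) x (WoPerBlock + X - 1) input tile, which is
        // what the two sizes above describe.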

        constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{},
            Number<InBlockCopyDataPerRead>{});

        constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<KPerBlock, CPerBlock * Y * X>{},
            Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy

        constexpr auto wei_kcyx_block_desc =
            make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
                                          Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});

        // shared mem
        constexpr index_t in_block_element_size =
            in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
        constexpr index_t wei_block_element_size =
            wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});

        constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
                                          ? InBlockCopyDataPerRead
                                          : WeiBlockCopyDataPerRead;

        __shared__ Float
            p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
        __shared__ Float
            p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];

        // threadwise tensors
        constexpr index_t HiPerThread = HoPerThread + Y - 1;
        constexpr index_t WiPerThread = WoPerThread + X - 1;

        constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor(
            Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
            in_nchw_block_desc.GetStrides());

        constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());

        constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
            in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);

        // register
        Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()];

        // divide block work
        constexpr index_t NBlockWork =
            (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
        constexpr index_t KBlockWork =
            (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork =
            (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork =
            (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

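        // Illustrative decomposition: with NBlockWork = 2, KBlockWork = 4,
        // HBlockWork = 3, WBlockWork = 5, block_id 37 decomposes below as
        //     n_block_work_id = 37 / (4 * 3 * 5) = 0, remainder 37
        //     k_block_work_id = 37 / (3 * 5)     = 2, remainder 7
        //     h_block_work_id = 7 / 5            = 1
        //     w_block_work_id = 7 - 1 * 5        = 2
        // i.e. blocks are laid out row-major over the [N, K, Ho, Wo] work grid.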
        const index_t block_id = blockIdx.x;

        index_t itmp = block_id;
        const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
        itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
        const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
        itmp -= k_block_work_id * (HBlockWork * WBlockWork);
        const index_t h_block_work_id = itmp / WBlockWork;
        const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;

        const index_t n_block_data_begin = n_block_work_id * NPerBlock;
        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
        const index_t wi_block_data_begin = wo_block_data_begin; // minus padding

        // divide thread work
        constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
        constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
        constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
        constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;

        const index_t thread_id = get_thread_local_1d_id();

        itmp = thread_id;
        const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
        itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
        const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
        itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
        const index_t h_thread_work_id = itmp / WThreadWork;
        const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;

        const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
        const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
        const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
        const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;

        const index_t hi_thread_data_begin = ho_thread_data_begin;
        const index_t wi_thread_data_begin = wo_thread_data_begin;

        constexpr auto blockwise_in_copy =
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_nchw_global_desc),
                                   decltype(in_nchw_block_desc),
                                   decltype(in_nchw_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead>{};

#if 0
        constexpr auto blockwise_wei_copy =
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_kcyx_global_desc),
                                   decltype(wei_kcyx_block_desc),
                                   decltype(wei_kcyx_block_desc.GetLengths()),
                                   1>{};
#elif 1
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_ke_global_desc),
                                   decltype(wei_ke_block_desc),
                                   decltype(wei_ke_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead>({0, 0}, {0, 0});
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C;
            c_block_data_begin += CPerBlock, __syncthreads())
        {
            // copy input tensor to LDS
            blockwise_in_copy.Run(
                p_in_global +
                    in_nchw_global_desc.GetOffsetFromMultiIndex(n_block_data_begin,
                                                                c_block_data_begin,
                                                                hi_block_data_begin,
                                                                wi_block_data_begin),
                p_in_block);

            // copy weight tensor to LDS
            blockwise_wei_copy.Run(p_wei_global +
                                       wei_kcyx_global_desc.GetOffsetFromMultiIndex(
                                           k_block_data_begin, c_block_data_begin, 0, 0),
                                   p_wei_block);

            __syncthreads();

            for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
            {
                // threadwise convolution
#if 1
                threadwise_direct_convolution_2(
                    in_nchw_thread_block_desc,
                    p_in_block +
                        in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
                                                                   c_thread_data,
                                                                   hi_thread_data_begin,
                                                                   wi_thread_data_begin),
                    wei_kcyx_thread_block_desc,
                    p_wei_block +
                        wei_kcyx_block_desc.GetOffsetFromMultiIndex(
                            k_thread_data_begin, c_thread_data, 0, 0),
                    out_nkhw_thread_desc,
                    p_out_thread);
#elif 0
                threadwise_direct_convolution_3(
                    in_nchw_thread_block_desc,
                    p_in_block +
                        in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
                                                                   c_thread_data,
                                                                   hi_thread_data_begin,
                                                                   wi_thread_data_begin),
                    wei_kcyx_thread_block_desc,
                    p_wei_block +
                        wei_kcyx_block_desc.GetOffsetFromMultiIndex(
                            k_thread_data_begin, c_thread_data, 0, 0),
                    out_nkhw_thread_desc,
                    p_out_thread);
#endif
            }
        }

        // copy output tensor from register to global mem
        threadwise_tensor_slice_copy(out_nkhw_thread_desc,
                                     p_out_thread,
                                     out_nkhw_global_desc,
                                     p_out_global +
                                         out_nkhw_global_desc.GetOffsetFromMultiIndex(
                                             n_block_data_begin + n_thread_data_begin,
                                             k_block_data_begin + k_thread_data_begin,
                                             ho_block_data_begin + ho_thread_data_begin,
                                             wo_block_data_begin + wo_thread_data_begin),
                                     out_nkhw_thread_desc.GetLengths(),
                                     Number<1>{});
    }
};

} // namespace ck
#endif
@@ -0,0 +1,399 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopyClusterLengths_CHWN,
          index_t InBlockCopyDataPerRead_N,
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);

        constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup");

        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // flattened (2d) tensor view of gridwise weight
        constexpr auto wei_cyx_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

        // tensor view of blockwise input and weight in LDS
        // be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
                                                WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{},
            Number<InBlockCopyDataPerRead_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock * Y * X, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, Y, X, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [C, Hi, Wi, N]
        const auto blockwise_in_copy =
#if 0
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead_N>{};
#else
            Blockwise4dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyClusterLengths_CHWN,
                                   InBlockCopyDataPerRead_N>{};
#endif

        // blockwise wei copy
        // format is [CPerBlock * Y * X, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_cyx_k_global_desc),
                                   decltype(wei_cyx_k_block_desc),
                                   decltype(wei_cyx_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
        // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
        // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
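        // In index form, each (y, x) step of the loop below in effect accumulates
        //     out[k, ho, wo, n] += wei[c, y, x, k] * in[c, ho + y, wo + x, n]
        // summed over the current CPerBlock slice of c; batching over ho makes this
        // a batched GEMM with M = K, N = Wo * N and contraction dimension C.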
        constexpr auto a_c_k_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<KPerBlock>{},
                                          Number<wei_c_y_x_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});

        constexpr index_t wei_block_space =
            wei_c_y_x_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        // C++ lambda doesn't capture arrays, use a pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

        const Float* p_in_global_block_offset =
            p_in_global +
            in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global +
            wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
            p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0),
            __syncthreads())
        {
#if 1
            blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
            blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#else
            Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
            Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

            blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                       p_in_register_clipboard);
            blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                        p_wei_register_clipboard);

            blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
            blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
#endif

            __syncthreads();

#pragma unroll
            for(index_t y = 0; y < Y; ++y)
            {
#pragma unroll
                for(index_t x = 0; x < X; ++x)
                {
#if 1
                    blockwise_batch_gemm.Run
#else
                    blockwise_batch_gemm.Run_asm
#endif
                        (p_wei_block + wei_c_y_x_k_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                         p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                         p_out_thread);
                }
            }
        }

        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

        // f_dummy does nothing but perfect forwarding; it makes the lambda generic,
        // so the body is not compiled until it is instantiated
        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
            static_assert(
                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
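            // The splits below carve the output K, Wo and N dimensions into
            //     K = K0 * K1 * K2,  Wo = W0 * W1 * W2,  N = N0 * N1 * N2
            // so that the values a thread holds in registers map onto a strided 10d
            // slice of the global output and can be written with one slice copy.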
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc =
                make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                       K1,
                                                       K2,
                                                       Ho,
                                                       Wo / (W1 * W2),
                                                       W1,
                                                       W2,
                                                       N / f_dummy(N1 * N2),
                                                       N1,
                                                       N2>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        }).Else([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;

            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                for(index_t i = 0; i < 64; ++i)
                {
                    printf("out %f, ", p_out_thread[i]);
                }
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        });
    }
};

} // namespace ck
#endif
@@ -0,0 +1,435 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_3d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopyClusterLengths_CHWN,
          index_t InBlockCopyDataPerRead_N,
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);

        constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
        constexpr index_t WiPerBlock = WoPerBlock + X - 1;

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup");

        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_x_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, X, K>{}, Sequence<Y * X * K, K, 1>{});
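        // With this view only one y-slice of the filter has to be resident in LDS
        // at a time: the C stride above skips a whole Y * X * K plane, and each
        // iteration of the y loop further down copies a [CPerBlock, X, KPerBlock]
        // tile at the corresponding y offset.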

        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
                                                WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{},
            Number<InBlockCopyDataPerRead_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, X, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [C, Hi, Wi, N]
#if 1
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyDataPerRead_N>{};
#else
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyClusterLengths_CHWN,
                                   InBlockCopyDataPerRead_N>{};
#endif

        // blockwise wei copy
        // format is [CPerBlock, X * KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise3dTensorCopy1<BlockSize,
                                   Float,
                                   decltype(wei_c_x_k_global_desc),
                                   decltype(wei_c_x_k_block_desc),
                                   decltype(wei_c_x_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
        // A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_x_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space =
            wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        // C++ lambda doesn't capture arrays, use a pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");

            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 1
        const Float* p_in_global_block_offset =
            p_in_global +
            in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global +
            wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
            p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
            for(index_t y = 0; y < Y; ++y)
            {
                blockwise_in_copy.Run(
                    p_in_global_block_offset +
                        in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
                    p_in_block);

                blockwise_wei_copy.Run(
                    p_wei_global_block_offset +
                        wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0),
                    p_wei_block);

                __syncthreads();

                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_batch_gemm.Run(
                        p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
                        p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
                        p_out_thread);
                }

                __syncthreads();
            }
        }
#else
        // this version uses many more registers; the reason is not yet understood
        for(index_t y = 0; y < Y; ++y)
        {
            const Float* p_in_global_block_offset =
                p_in_global +
                in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                    0, hi_block_data_begin + y, wi_block_data_begin, n_block_data_begin);

            const Float* p_wei_global_block_offset =
                p_wei_global +
                wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, k_block_data_begin);

            for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                c_block_data_begin += CPerBlock,
                p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
                p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
            {
                blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);

                blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);

                __syncthreads();

                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_batch_gemm.Run(
                        p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0),
                        p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0),
                        p_out_thread);
                }

                __syncthreads();
            }
        }
#endif

        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

        // f_dummy does nothing but perfect forwarding; it makes the lambda generic,
        // so the body is not compiled until it is instantiated
        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
            static_assert(
                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc =
                make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                       K1,
                                                       K2,
                                                       Ho,
                                                       Wo / (W1 * W2),
                                                       W1,
                                                       W2,
                                                       N / f_dummy(N1 * N2),
                                                       N1,
                                                       N2>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        }).Else([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;

            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                for(index_t i = 0; i < 64; ++i)
                {
                    printf("out %f, ", p_out_thread[i]);
                }
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        });
    }
};

} // namespace ck
#endif
@@ -0,0 +1,425 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"
#include "threadwise_4d_tensor_op.hpp"
#include "blockwise_batched_gemm.hpp"

namespace ck {

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockCopyClusterLengths_CHWN,
          index_t InBlockCopyDataPerRead_N,
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert(
            NPerBlock % NPerThread == 0 &&
                ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
                 (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                  GemmNPerThreadSubC % NPerThread == 0)),
            "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);

        constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work for workgroup");

        constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
        constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
        constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
        constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);

        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
            Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});

        const auto block_work_multi_id =
            block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

        const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
        const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
        const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
        const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});

        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
                                                WeiBlockCopyDataPerRead_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
            Number<InBlockCopyDataPerRead_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: format is [C, Hi, Wi, N]
        const auto blockwise_in_copy =
            Blockwise4dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(in_c_h_w_n_global_desc),
                                   decltype(in_c_h_w_n_block_desc),
                                   decltype(in_c_h_w_n_block_desc.GetLengths()),
                                   InBlockCopyClusterLengths_CHWN,
                                   InBlockCopyDataPerRead_N>{};

        // blockwise wei copy
        // format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        // C_matrix += transpose(A_matrix) * B_matrix
        // A_matrix and B_matrix saved in LDS, C_matrix saved in register
        // A_matrix[C,K] is a sub-matrix of wei_block[C,K]
        // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};

        // choose GEMM implementation here
        const auto run_blockwise_batch_gemm = [&](auto... Xs) {
#if 0
            return blockwise_batch_gemm.Run(Xs...);
#elif 0
            return blockwise_batch_gemm.Run_asm(Xs...);
#else
            return blockwise_batch_gemm.Run_asm_v2(Xs...);
#endif
        };

        // LDS: be careful of alignment
        // TODO: need to properly implement tensor descriptor with alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];

        // register
        // C++ lambda doesn't capture arrays, use a pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");

            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 1
        const Float* p_in_global_block_offset =
            p_in_global +
            in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global +
            wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
            p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
            for(index_t y = 0; y < Y; ++y)
            {
#pragma unroll
                for(index_t x = 0; x < X; ++x)
                {
                    blockwise_in_copy.Run(
                        p_in_global_block_offset +
                            in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                        p_in_block);

                    blockwise_wei_copy.Run(
                        p_wei_global_block_offset +
                            wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
                        p_wei_block);

                    __syncthreads();

                    run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#else
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global +
                    in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
                        0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);

                const Float* p_wei_global_block_offset =
                    p_wei_global +
                    wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);

                for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                    c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
                {
                    blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);

                    blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);

                    __syncthreads();

                    run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#endif

        // output: register to global mem
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

        // f_dummy does nothing but perfect forwarding; it makes the lambda generic,
        // so the body is not compiled until it is instantiated
        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
            static_assert(
                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                "wrong!");

            // output is a 10d tensor
            constexpr index_t N2 = GemmNPerThreadSubC;
            constexpr index_t N1 = NPerBlock / N2;

            constexpr index_t W2 =
                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
            constexpr index_t W1 = WoPerBlock / W2;

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc =
                make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                       K1,
                                                       K2,
                                                       Ho,
                                                       Wo / (W1 * W2),
                                                       W1,
                                                       W2,
                                                       N / f_dummy(N1 * N2),
                                                       N1,
                                                       N2>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
                                                 n_block_data_begin + n_thread_data_begin),
                                         out_10d_thread_desc.GetLengths(),
                                         Number<OutThreadCopyDataPerWrite_N>{});
        }).Else([&](auto f_dummy) {
            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                              GemmNPerThreadSubC % NPerThread == 0,
                          "wrong!");

            // output is a 10d tensor
            constexpr index_t N1 = NPerBlock;

            constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
            constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

            constexpr index_t K2 = GemmMPerThreadSubC;
            constexpr index_t K1 = KPerBlock / KPerThread;

            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});

            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
            if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
            {
                print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                               "out_k_h_w_n_thread_desc");
                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                               "out_k_h_w_n_global_desc");
                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                for(index_t i = 0; i < 64; ++i)
                {
                    printf("out %f, ", p_out_thread[i]);
                }
            }
#endif

            threadwise_tensor_slice_copy(out_10d_thread_desc,
                                         p_out_thread,
                                         out_10d_global_desc,
                                         p_out_global +
                                             out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
                                                 k_block_data_begin + k_thread_data_begin,
                                                 ho_block_data_begin + ho_thread_data_begin,
                                                 wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,475 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyClusterLengths_CHWN,
|
||||
index_t InBlockCopyDataPerRead_N,
|
||||
index_t WeiBlockCopyDataPerRead_K,
|
||||
index_t OutThreadCopyDataPerWrite_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// assert for LDS double buffer
|
||||
static_assert(C % (2 * CPerBlock) == 0, "wrong! C must be divisible by 2 * CPerBlock for LDS double buffering");
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
|
||||
const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
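// Worked example (hypothetical sizes): with K = 256, KPerBlock = 128, Ho = Wo = 32,
// HoPerBlock = WoPerBlock = 16, N = 128, NPerBlock = 32, block_work_desc has lengths
// [2, 2, 2, 4]. Assuming the usual row-major unpacking (last dimension fastest),
// block id 13 -> multi-index (0, 1, 1, 1) -> data begins (k, ho, wo, n) = (0, 16, 16, 32).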
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
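// max_align is the least common multiple of every vectorized access width used on the
// LDS tiles (the blockwise-copy writes and the GEMM reads); padding the element space
// to it keeps both tiles, and the second half of each double buffer, suitably aligned.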
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
|
||||
Number<InBlockCopyDataPerRead_N>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
#else
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
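// Judging by the template arguments passed below, Ho serves as the batch dimension:
// HoPerBlock small GEMMs are issued per call, A uses a batch stride of 0 (the same weight
// sub-matrix is reused for every output row), while B and C advance by their respective
// H strides between rows.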
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
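// The #if cascade below picks the GEMM implementation at compile time. The plain C++
// Run() is enabled; Run_asm / Run_asm_v2 are alternative (presumably hand-tuned inline
// assembly) variants that can be switched in by editing the guards. Wrapping the call in
// a lambda keeps the rest of the kernel independent of which variant is selected.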
|
||||
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_batch_gemm.Run(Xs...);
|
||||
#elif 0
|
||||
return blockwise_batch_gemm.Run_asm(Xs...);
|
||||
#else
|
||||
return blockwise_batch_gemm.Run_asm_v2(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
// LDS double buffer
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
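// Ping-pong layout: buffer 0 occupies [0, in_block_space) / [0, wei_block_space), buffer 1
// the second half of each array. While the GEMM consumes the slab in one half, the next
// CPerBlock slab is staged into the other half, and the roles swap every iteration of the
// main loop below.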
|
||||
|
||||
// register
|
||||
// C++ lambda can't capture an array, so use a pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output to 0
|
||||
threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
// LDS double buffer: preload data into LDS
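// The first CPerBlock slab travels global memory -> register clipboard -> LDS before the
// main loop starts, so every main-loop iteration can overlap the next global load with
// the GEMM on data already resident in LDS.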
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
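// Each unrolled iteration below: bump the global offsets by one CPerBlock slab, sync,
// issue the global->register loads for the next slab, run the GEMM on the slab currently
// in LDS, then commit the register clipboard into the other half of the double buffer.
// even_loop decides which half is "now" and which is "next".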
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float
|
||||
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
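// The main loop stops 2 * CPerBlock early, so two slabs remain. The even tail iteration
// loads the last slab while running the GEMM on the slab already in LDS; after the final
// sync, the GEMM on the last slab needs no further global loads.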
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
// even iteration
|
||||
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(
|
||||
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: copy from register to global memory
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
|
||||
// fwd does nothing but perfect forwarding.
|
||||
// Using this trick to make this lambda a generic lambda, so it won't be compiled until
|
||||
// being instantiated here
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
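// Fold splits a dimension of length L into (L / (l1 * l2), l1, l2) sub-dimensions while
// keeping the underlying strides, so both 4d KHWN views become 10d views over the same
// memory. The Number<1> factors in the thread-side folds appear to mark sub-dimensions
// that are distributed across the thread cluster, so each thread owns a single element
// along them and the slice copy below can write its whole register tile in one call.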
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"a: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"b: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_tensor_slice_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,451 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockReorderSrcSubLengths_NCHW,
|
||||
class InBlockReorderSrcClusterLengths_NCHW,
|
||||
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
index_t InBlockReorderDataPerRead_W,
|
||||
index_t InBlockReorderDataPerWrite_N,
|
||||
class WeiBlockCopyClusterLengths_CK, // not used
|
||||
index_t WeiBlockCopyDataPerRead_K,
|
||||
index_t OutThreadCopyDataPerWrite_W>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t N = out_n_k_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work: [N, K, Ho, Wo]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
|
||||
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
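// 2d [C, K] view of the [C, Y, X, K] weight: stride Y * X * K along C and 1 along K.
// The (y, x, k_block) offset into p_wei_global is still computed with the full 4d
// descriptor before the 2d blockwise copy runs.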
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
|
||||
Number<InBlockReorderDataPerWrite_N>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: reorder from [N, C, Hi, Wi] (global) to [C, Hi, Wi, N] (LDS)
|
||||
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
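// map_chwn2nchw[i] gives, for dimension i of the CHWN block tile, the matching dimension
// of the NCHW global tensor: C <- 1, Hi <- 2, Wi <- 3, N <- 0.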
|
||||
|
||||
const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
decltype(map_chwn2nchw),
|
||||
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_batch_gemm.Run(Xs...);
|
||||
#elif 0
|
||||
return blockwise_batch_gemm.Run_asm(Xs...);
|
||||
#else
|
||||
return blockwise_batch_gemm.Run_asm_v2(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register
|
||||
// C++ lambda can't capture an array, so use a pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
#if 0
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x),
|
||||
p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy_reorder.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// output: copy from register to global memory
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
|
||||
// fwd does nothing but perfect forwarding.
|
||||
// Using this trick to make this lambda a generic lambda, so it won't be compiled until
|
||||
// being instantiated here
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"a: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
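// Maps each global-side dimension (NKHW fold order) to the matching thread-side dimension
// (KHWN fold order): the three N folds come from thread dims 7-9, the K folds from 0-2,
// Ho from 3, and the Wo folds from 4-6.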
|
||||
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"b: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
|
||||
|
||||
#if 0
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread,
|
||||
Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#else
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
|
||||
p_out_thread,
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
|
||||
arithmetic_sequence_gen<0, 10, 1>::SeqType{},
|
||||
Number<1>{});
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,502 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockReorderSrcSubLengths_NCHW,
|
||||
class InBlockReorderSrcClusterLengths_NCHW,
|
||||
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
index_t InBlockReorderDataPerRead_W,
|
||||
index_t InBlockReorderDataPerWrite_N,
|
||||
class WeiBlockCopyClusterLengths_CK, // not used
|
||||
index_t WeiBlockCopyDataPerRead_K,
|
||||
index_t OutThreadCopyDataPerWrite_W>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t N = out_n_k_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// assert for LDS double buffer
|
||||
static_assert(C % (2 * CPerBlock) == 0, "wrong! C must be divisible by 2 * CPerBlock for LDS double buffering");
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
|
||||
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
|
||||
Number<InBlockReorderDataPerWrite_N>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment requirements
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: reorder from [N, C, Hi, Wi] (global) to [C, Hi, Wi, N] (LDS)
|
||||
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
|
||||
|
||||
const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
decltype(map_chwn2nchw),
|
||||
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_batch_gemm.Run(Xs...);
|
||||
#elif 0
|
||||
return blockwise_batch_gemm.Run_asm(Xs...);
|
||||
#else
|
||||
return blockwise_batch_gemm.Run_asm_v2(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
// LDS double buffer
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register
|
||||
// C++ lambda can't capture an array, so use a pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy_reorder
|
||||
.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_clipboard[blockwise_in_copy_reorder
|
||||
.GetRegisterClipboardSize()];
|
||||
Float
|
||||
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy_reorder
|
||||
.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
// even iteration
|
||||
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy_reorder.RunStoreRegisterClipboard(
|
||||
p_in_register_clipboard, p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(
|
||||
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: copy from register to global memory
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
|
||||
// fwd does nothing but perfect forwarding.
|
||||
// Using this trick to make this lambda a generic lambda, so it won't be compiled until
|
||||
// being instantiated here
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
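// folding N, K and W by the per-thread tiling factors turns the 4d output into a
// 10d view whose dimensions match the register layout of the thread tile, so the
// copy below reduces to a dimension-reordered slice copy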
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"a: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{};
|
||||
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_n_k_h_w_global_desc)
|
||||
.Fold(I3, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_n_k_h_w_global_desc,
|
||||
"b: out_n_k_h_w_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{};
|
||||
|
||||
#if 0
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread,
|
||||
Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#else
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread),
|
||||
p_out_thread,
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
make_zero_array<index_t, 10>(),
|
||||
out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread),
|
||||
arithmetic_sequence_gen<0, 10, 1>::SeqType{},
|
||||
Number<1>{});
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,284 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = flatten(Hi, Wi, N)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t BPerThread,
|
||||
index_t KPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
index_t WeiBlockCopyThreadPerDim1,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t B = N * Hi * Wi;
|
||||
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
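// ghost region: because B flattens the spatial dims, a block producing BPerBlock
// outputs must also read the (Y - 1) * Wi + (X - 1) trailing inputs swept by the
// Y x X filter window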
|
||||
|
||||
// divide block work by 2d: [K, B]
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
|
||||
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
|
||||
|
||||
// flattened (2d) tensor view of gridwise input
|
||||
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
// be careful of alignment
|
||||
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
|
||||
|
||||
// blockwise in copy
|
||||
// format is [CPerBlock, BPerBlock + BGhostRead]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*Y*X,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
|
||||
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
|
||||
// c_mtx[K,B] is out_block[K,B]
|
||||
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_kxb_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
|
||||
constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output to 0
|
||||
threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
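// the __syncthreads() in the loop increment separates this iteration's GEMM reads
// from the next iteration's LDS stores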
|
||||
// load data
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 1
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
|
||||
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
|
||||
|
||||
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
{
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
|
||||
|
||||
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
|
||||
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
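// decode the flattened index: b = (h * Wi + w) * N + n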
|
||||
|
||||
index_t h_data = b_data / (Wi * N);
|
||||
index_t itmp = b_data - h_data * (Wi * N);
|
||||
index_t w_data = itmp / N;
|
||||
index_t n_data = itmp - w_data * N;
|
||||
|
||||
if(n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
|
||||
k_data, h_data, w_data, n_data)] =
|
||||
p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,413 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = flatten(Hi, Wi, N)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t BPerThread,
|
||||
index_t KPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
index_t InBlockCopyThreadPerDim0,
|
||||
index_t InBlockCopyThreadPerDim1,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
index_t WeiBlockCopyThreadPerDim1,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_chwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wi = in_chwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = in_chwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t B = N * Hi * Wi;
|
||||
constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1);
|
||||
|
||||
// assert for LDS double buffer
|
||||
static_assert(C % (2 * CPerBlock) == 0, "wrong! C must be a multiple of 2 * CPerBlock");
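// the main loop consumes two C-slices per iteration (one per LDS buffer) and the
// tail consumes two more, so C must be a multiple of 2 * CPerBlock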
|
||||
|
||||
// divide block work by 2d: [K, B]
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / BBlockWork;
|
||||
const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t b_block_data_begin = b_block_work_id * BPerBlock;
|
||||
|
||||
// flattened (2d) tensor view of gridwise input
|
||||
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
// be careful of alignment
|
||||
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
|
||||
|
||||
// blockwise in copy
|
||||
// format is [CPerBlock, BPerBlock + BGhostRead]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*Y*X,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K]
|
||||
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
|
||||
// c_mtx[K,B] is out_block[K,B]
|
||||
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<BPerBlock>{}, Number<in_cb_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_kxb_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<BPerThread>{});
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t max_align =
|
||||
math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);
|
||||
|
||||
constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space =
|
||||
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
// LDS double buffer
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_double);
|
||||
}
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output to 0
|
||||
threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
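// ping-pong between the two halves of the LDS double buffer: GEMM reads the
// "now" half while the next C-slice is staged into the "next" half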
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
// load next data
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_now +
|
||||
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
// even
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
|
||||
p_in_register_clipboard);
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_double +
|
||||
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block_double + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd
|
||||
__syncthreads();
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
#if 1
|
||||
blockwise_gemm.Run
|
||||
#elif 0
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#elif 0
|
||||
blockwise_gemm.Run_asm
|
||||
#endif
|
||||
(p_wei_block_double + wei_block_space +
|
||||
wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block_double + in_block_space + y * Wi + x,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row;
|
||||
const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col;
|
||||
|
||||
if(Y == 1 && X == 1)
|
||||
{ // pure 1x1 conv (no padding, unit stride)
|
||||
constexpr index_t K2_ = GemmMPerThreadSubC;
|
||||
constexpr index_t K1_ = KPerBlock / KPerThread;
|
||||
constexpr index_t B2_ = GemmNPerThreadSubC;
|
||||
constexpr index_t B1_ = BPerBlock / BPerThread;
|
||||
|
||||
constexpr auto out_6d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1_ * K2_), K1_, K2_, B / (B1_ * B2_), B1_, B2_>{});
|
||||
|
||||
constexpr auto out_6d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerBlock / (K1_ * K2_), 1, K2_, BPerBlock / (B1_ * B2_), 1, B2_>{});
|
||||
|
||||
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
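// with a 1x1 filter (no padding, unit stride) B equals N * Ho * Wo exactly, so the
// thread tile can be written back with a strided 6d copy instead of the
// bounds-checked per-element path below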
|
||||
|
||||
threadwise_6d_tensor_copy(out_6d_thread_desc,
|
||||
p_out_thread,
|
||||
out_6d_global_desc,
|
||||
p_out_global +
|
||||
out_kb_global_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, b_thread_data_begin),
|
||||
out_6d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
{
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b);
|
||||
|
||||
index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row;
|
||||
index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col;
|
||||
|
||||
index_t h_data = b_data / (Wi * N);
|
||||
index_t itmp = b_data - h_data * (Wi * N);
|
||||
index_t w_data = itmp / N;
|
||||
index_t n_data = itmp - w_data * N;
|
||||
|
||||
if(n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex(
|
||||
k_data, h_data, w_data, n_data)] =
|
||||
p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,377 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_C_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_C_K,
|
||||
class WeiBlockCopyClusterLengths_C_K,
|
||||
index_t WeiBlockCopyDataPerAccess_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find a more elegant way of specifying (or calculating) performance parameters
|
||||
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
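// B is the GEMM "N" dimension of the implicit GEMM: one output column per (n0, ho, wo)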
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % CPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// memory layout descriptor in device memory [N0, N1, N2, C, H, W]
|
||||
constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
|
||||
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number<Ho>{}).Slice(I5, Number<Wo>{}),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
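// merged view [C, N1, B, N2] of the folded input, with H and W restricted to the
// Ho x Wo output window; the per-(y, x) origin shift is applied through the global
// pointer offset when the copy is run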
|
||||
|
||||
// memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_n1_b_n2_global_merged_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()),
|
||||
InBlockCopySubLengths_C_N1_B_N2,
|
||||
InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B]
|
||||
Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B]
|
||||
Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2]
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[CPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<N1 * BPerBlock * N2>{},
|
||||
Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_gemm.Run(Xs...);
|
||||
#else
|
||||
return blockwise_gemm.Run_asm(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
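// every vectorized LDS access (copy writes and GEMM reads) must stay aligned, so
// the element space of each block is padded to the lcm of all the access widths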
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
#if 0
|
||||
// do work
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// calculate origin of block input and weight tensor on global memory
|
||||
const Float* p_in_block_on_global =
|
||||
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
|
||||
|
||||
const Float* p_wei_block_on_global =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
|
||||
|
||||
for(index_t
|
||||
c_block_data_on_global = 0;
|
||||
c_block_data_on_global < C;
|
||||
c_block_data_on_global += CPerBlock,
|
||||
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// calculate origin of block input and weight tensor on global memory
|
||||
const Float* p_in_block_on_global =
|
||||
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
|
||||
|
||||
const Float* p_wei_block_on_global =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
|
||||
|
||||
for(index_t c_block_data_on_global = 0; c_block_data_on_global < C;
|
||||
c_block_data_on_global += CPerBlock)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_block_on_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(
|
||||
I0, Number<CPerBlock>{}, True);
|
||||
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(
|
||||
I0, Number<CPerBlock>{}, True);
|
||||
}
|
||||
|
||||
// move the slicing window back to the start of C
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
|
||||
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<C>{}, False);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / (K1 * K2);
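// K is decomposed as K0 * K1 * K2 to mirror how the blockwise GEMM distributes rows
// of C: K2 = GemmMPerThreadSubC contiguous rows per thread, K1 = number of M-thread
// clusters, K0 = the remaining repeats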
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,404 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_C_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_C_K,
|
||||
class WeiBlockCopyClusterLengths_C_K,
|
||||
index_t WeiBlockCopyDataPerAccess_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find a more elegant way of specifying (or calculating) performance parameters
|
||||
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % (2 * CPerBlock) == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// memory layout descriptor in device memory [N0, N1, N2, C, H, W]
|
||||
constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc =
|
||||
in_n_c_h_w_global_desc.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number<Ho>{}).Slice(I5, Number<Wo>{}),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_n1_b_n2_global_merged_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc),
|
||||
decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()),
|
||||
InBlockCopySubLengths_C_N1_B_N2,
|
||||
InBlockCopyClusterLengths_C_N1_B_N2,
|
||||
Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B]
|
||||
Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B]
|
||||
Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2]
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDataPerAccess_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[CPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<N1 * BPerBlock * N2>{},
|
||||
Number<in_c_n1_b_n2_block_mem_desc.GetStride(I0)>{});
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_gemm.Run(Xs...);
|
||||
#else
|
||||
return blockwise_gemm.Run_asm(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// do work
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// calculate origin of block input and weight tensor on global memory
|
||||
const Float* p_in_block_on_global =
|
||||
p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x);
|
||||
|
||||
const Float* p_wei_block_on_global =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_block_on_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
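// Sketch of the intent of the main loop below, inferred from the code: it ping-pongs over
// two LDS buffers, so while the GEMM consumes the "now" buffer, the next CPerBlock-slice is
// fetched from global memory into registers and then stored into the "next" buffer, letting
// global loads overlap with math. Each trip of the outer loop advances by 2 * CPerBlock and
// the unrolled inner loop swaps the roles of the two buffers.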
// LDS double buffer: main body
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
|
||||
c_block_data_begin += 2 * CPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float
|
||||
p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
// even iteration
|
||||
p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
|
||||
p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_block_on_global,
|
||||
p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(
|
||||
p_wei_register_clipboard, p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / (K1 * K2);
|
||||
|
||||
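// The K dimension is folded as K = K0 * K1 * K2 (inferred from the Fold() below): K2 rows
// come from one thread's sub-tile, K1 from the M-direction thread clusters, and K0 is the
// remainder. Example with assumed numbers: GemmMPerThreadSubC = 4 and
// GemmMLevel0Cluster = GemmMLevel1Cluster = 4 give K1 = 16, so K = 128 folds into
// K0 = 2, K1 = 16, K2 = 4.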
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
#endif
|
||||
@@ -0,0 +1,354 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

namespace ck {
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
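// Illustrative sizes (assumed numbers, for intuition only): N = 128 with N1 = 2, N2 = 4
// gives N0 = 16; with Ho = Wo = 14 the merged GEMM-N dimension is B = 16 * 14 * 14 = 3136,
// and with C = 192, Y = X = 3 the reduction dimension is E = 192 * 3 * 3 = 1728.
// The implicit GEMM then produces a K x B output accumulated over E.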
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{})
|
||||
.Slice(I3, Number<Wo>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
|
||||
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
|
||||
.Slice(I3, Number<X>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<3, 6, 7>{},
|
||||
Sequence<5>{});
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_n0_n1_n2_h_w_global_desc,
|
||||
"in_n0_n1_n2_h_w_global_desc: ");
|
||||
print_ConstantTensorDescriptor(in_c_y_x_global_desc, "in_c_y_x_global_desc: ");
|
||||
print_ConstantMergedTensorDescriptor(in_e_n1_b_n2_global_merged_desc,
|
||||
"in_e_n1_b_n2_global_merged_desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>(
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<EPerBlock>{}, Number<KPerBlock>{}, Number<wei_e_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<EPerBlock>{},
|
||||
Number<N1 * BPerBlock * N2>{},
|
||||
Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
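// Per-thread C tile size (an assumed example, not a required config): with KPerBlock = 128,
// GemmMPerThreadSubC = 4 and GemmMLevel0Cluster = GemmMLevel1Cluster = 4,
// GemmMRepeat = 128 / (4 * 4 * 4) = 2, so each thread holds a
// (2 * 4) x (N1 * N2) = 8 x 8 accumulator tile when N1 = 2 and N2 = 4.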
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_gemm.Run(Xs...);
|
||||
#else
|
||||
return blockwise_gemm.Run_asm(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
in_e_n1_b_n2_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space = wei_e_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// do work
|
||||
for(index_t e = 0; e < E; e += EPerBlock)
|
||||
{
|
||||
// march the slicing window along E: copy the current slice into LDS
blockwise_in_copy.Run(p_in_global, p_in_block);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
}
|
||||
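// Note: the loop above walks the slicing window over the whole reduction dimension in
// EPerBlock steps, i.e. E / EPerBlock iterations (e.g. 1728 / 8 = 216 with the assumed
// numbers above); unlike the *_lds_double_buffer variant, compute and global loads here
// are serialized by the two __syncthreads() per iteration.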
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / (K1 * K2);
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,415 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"

#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif
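// Note: CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM defaults to 1 here; together with
// CK_USE_AMD_INLINE_ASM it makes run_blockwise_gemm below dispatch to
// blockwise_gemm.Run_asm() (the AMD inline-assembly path) instead of the generic C++ Run().
// Presumably it can be overridden at compile time, e.g. with -DCK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM=0
// (compiler invocation shown only as an illustrative assumption).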
|
||||
namespace ck {
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among threads");
|
||||
constexpr index_t N0 = N / (N1 * N2);
|
||||
|
||||
constexpr index_t B = N0 * Ho * Wo;
|
||||
|
||||
constexpr index_t E = C * Y * X;
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{})
|
||||
.Slice(I3, Number<Wo>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
|
||||
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
|
||||
.Slice(I3, Number<X>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<3, 6, 7>{},
|
||||
Sequence<5>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>(
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<EPerBlock>{}, Number<KPerBlock>{}, Number<wei_e_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<EPerBlock>{},
|
||||
Number<N1 * BPerBlock * N2>{},
|
||||
Number<in_e_n1_b_n2_block_desc.GetStride(I0)>{});
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t GemmMRepeat =
|
||||
KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
BlockSize,
|
||||
decltype(a_e_k_block_mtx_desc),
|
||||
decltype(b_e_n1bn2_block_mtx_desc),
|
||||
decltype(c_k0k2_n1n2_thread_mtx_desc),
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_gemm = [&](auto... Xs) {
|
||||
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
return blockwise_gemm.Run_asm(Xs...);
|
||||
#else
|
||||
return blockwise_gemm.Run(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS allocation for input and weight: be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
in_e_n1_b_n2_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
constexpr index_t wei_block_space = wei_e_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block_double[2 * in_block_space];
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
|
||||
e_block_data_begin += 2 * EPerBlock)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop = 0; iloop < 2; ++iloop)
|
||||
{
|
||||
const bool even_loop = (iloop % 2 == 0);
|
||||
|
||||
Float* p_in_block_now =
|
||||
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
|
||||
Float* p_wei_block_now =
|
||||
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;
|
||||
|
||||
Float* p_in_block_next =
|
||||
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
|
||||
Float* p_wei_block_next =
|
||||
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;
|
||||
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_next);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_next);
|
||||
}
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
blockwise_in_copy.RunLoadRegisterClipboard(p_in_global, p_in_register_clipboard);
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_block_on_global,
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
run_blockwise_gemm(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / (K1 * K2);
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc,
|
||||
p_out_thread,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc,
|
||||
p_out_thread_on_global,
|
||||
{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(),
|
||||
arithmetic_sequence_gen<0, 8, 1>::SeqType{},
|
||||
Number<1>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,259 @@
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_direct_convolution.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "threadwise_direct_convolution.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class TInWei,
|
||||
class TOut,
|
||||
class TAccum,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t ScalarPerVector,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t CPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t BlockSize,
|
||||
index_t GridSize>
|
||||
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
const typename vector_type<TInWei,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
|
||||
const typename vector_type<TInWei,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
|
||||
TOut* const __restrict__ p_out_global)
|
||||
{
|
||||
using in_scalar_t = TInWei;
|
||||
using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;
|
||||
using out_scalar_t = TOut;
|
||||
using accum_t = TAccum;
|
||||
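// Assumption for readability: vector_type<TInWei, ScalarPerVector>::MemoryType is the packed
// vector form of the scalar type (e.g. a float4-like type when TInWei is float and
// ScalarPerVector is 4), so every element of p_in_vec_global / p_wei_vec_global carries
// ScalarPerVector input/weight scalars, while the output p_out_global stays scalar TOut.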
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_vec_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_nchw_vec_global_desc.GetLength(I0);
|
||||
constexpr index_t K = wei_kcyx_vec_global_desc.GetLength(I0);
|
||||
constexpr index_t C = wei_kcyx_vec_global_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_kcyx_vec_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_kcyx_vec_global_desc.GetLength(I3);
|
||||
|
||||
constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ke_vec_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<KPerBlock, CPerBlock * Y * X>{},
|
||||
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr auto wei_kcyx_vec_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
|
||||
Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});
|
||||
|
||||
// shared mem
|
||||
constexpr index_t in_block_element_size =
|
||||
in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr index_t wei_block_element_size =
|
||||
wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ in_vector_mem_t
|
||||
p_in_vec_block[max_align * ((in_block_element_size + max_align - 1) / max_align)];
|
||||
__shared__ in_vector_mem_t
|
||||
p_wei_vec_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)];
|
||||
|
||||
// threadwise tensors
|
||||
constexpr index_t HiPerThread = HoPerThread + Y - 1;
|
||||
constexpr index_t WiPerThread = WoPerThread + X - 1;
|
||||
|
||||
constexpr auto in_nchw_vec_thread_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
|
||||
in_nchw_vec_block_desc.GetStrides());
|
||||
|
||||
constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());
|
||||
|
||||
constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);
|
||||
|
||||
// register
|
||||
out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
|
||||
|
||||
// divide block work
|
||||
constexpr index_t NBlockWork = (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
|
||||
constexpr index_t KBlockWork = (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork =
|
||||
(out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
|
||||
|
||||
const index_t block_id = blockIdx.x;
|
||||
|
||||
index_t itmp = block_id;
|
||||
const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
|
||||
itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
|
||||
const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
|
||||
itmp -= k_block_work_id * (HBlockWork * WBlockWork);
|
||||
const index_t h_block_work_id = itmp / WBlockWork;
|
||||
const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
|
||||
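// The divisions above undo a row-major linearization of the block index, i.e.
// block_id == ((n_block_work_id * KBlockWork + k_block_work_id) * HBlockWork
//              + h_block_work_id) * WBlockWork + w_block_work_id.
// Worked example with assumed counts KBlockWork = 4, HBlockWork = 2, WBlockWork = 2:
// block_id = 13 decodes to n = 0, k = 3, h = 0, w = 1.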
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin; // minus padding
|
||||
const index_t wi_block_data_begin = wo_block_data_begin; // minus padding
|
||||
|
||||
// divide thread work
|
||||
constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
|
||||
constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread;
|
||||
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
|
||||
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
|
||||
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
itmp = thread_id;
|
||||
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
|
||||
itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork);
|
||||
const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork);
|
||||
itmp -= k_thread_work_id * (HThreadWork * WThreadWork);
|
||||
const index_t h_thread_work_id = itmp / WThreadWork;
|
||||
const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork;
|
||||
|
||||
const index_t n_thread_data_begin = n_thread_work_id * NPerThread;
|
||||
const index_t k_thread_data_begin = k_thread_work_id * KPerThread;
|
||||
const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread;
|
||||
const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread;
|
||||
|
||||
const index_t hi_thread_data_begin = ho_thread_data_begin;
|
||||
const index_t wi_thread_data_begin = wo_thread_data_begin;
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
in_vector_mem_t,
|
||||
decltype(in_nchw_vec_global_desc),
|
||||
decltype(in_nchw_vec_block_desc),
|
||||
decltype(in_nchw_vec_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
|
||||
#if 0
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
in_vector_mem_t,
|
||||
decltype(wei_kcyx_vec_global_desc),
|
||||
decltype(wei_kcyx_vec_block_desc),
|
||||
decltype(wei_kcyx_vec_block_desc.GetLengths()),
|
||||
1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
in_vector_mem_t,
|
||||
decltype(wei_ke_vec_global_desc),
|
||||
decltype(wei_ke_vec_block_desc),
|
||||
decltype(wei_ke_vec_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
#if 1 // debug
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
|
||||
#endif
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(
|
||||
p_in_vec_global +
|
||||
in_nchw_vec_global_desc.GetOffsetFromMultiIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_vec_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(p_wei_vec_global +
|
||||
wei_kcyx_vec_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_vec_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
|
||||
{
|
||||
// threadwise convolution
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// copy output tensor from register to global mem
|
||||
threadwise_4d_tensor_copy(out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global +
|
||||
out_nkhw_global_desc.GetOffsetFromMultiIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths());
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,298 @@
|
||||
#pragma once
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class LowerPads,
|
||||
class UpperPads,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t CPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t WeiBlockCopyThreadPerDim0,
|
||||
index_t WeiBlockCopyThreadPerDim1>
|
||||
__global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
|
||||
const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
{
|
||||
// NPerThread == NPerBlock, because the input is laid out in LDS as [C, Hi, Wi, N];
// for the GEMM trans([C, K]) * [C, Wo * N], a single thread has to cover all of "N".
// If the LDS layout were [C, Hi, N, Wi, N], then NPerThread could differ from NPerBlock.
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
|
||||
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HPadLow = LowerPads{}.Get(I0);
|
||||
constexpr index_t WPadLow = LowerPads{}.Get(I1);
|
||||
|
||||
constexpr index_t HPadUp = UpperPads{}.Get(I0);
|
||||
constexpr index_t WPadUp = UpperPads{}.Get(I1);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
|
||||
const index_t w_block_work_id = itmp / NBlockWork;
|
||||
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
|
||||
// flattened (2d) tensor view of wei in global mem
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight in LDS
|
||||
constexpr auto in_chwn_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{});
|
||||
|
||||
constexpr auto wei_cyxk_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, Y, X, KPerBlock>{});
|
||||
|
||||
// flattened (2d) tensor view of wei in LDS
|
||||
constexpr auto wei_ek_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock * Y * X, KPerBlock>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_hkwn_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<HoPerThread, KPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_cyxk_block_desc, "wei_cyxk_block_desc");
|
||||
print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
const index_t h_block_pad_low = h_block_work_id == 0 ? HPadLow : 0;
|
||||
const index_t w_block_pad_low = w_block_work_id == 0 ? WPadLow : 0;
|
||||
|
||||
const index_t h_block_pad_up = h_block_work_id == HBlockWork - 1 ? HPadUp : 0;
|
||||
const index_t w_block_pad_up = w_block_work_id == WBlockWork - 1 ? WPadUp : 0;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
;
|
||||
{
|
||||
printf(
|
||||
"%u %u, h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
h_block_pad_low,
|
||||
w_block_pad_low,
|
||||
h_block_pad_up,
|
||||
w_block_pad_up);
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
BlockwiseChwnTensorCopyPadded<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths()),
|
||||
LowerPads>{};
|
||||
|
||||
#if 0
|
||||
// weight: format is [C,Y,X,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_cyxk_global_desc),
|
||||
decltype(wei_cyxk_block_desc),
|
||||
decltype(wei_cyxk_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
// weight: format is [C*Y*X,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 1
|
||||
// weight: format is [C*Y*X,K]
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
|
||||
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_cxwn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_chwn_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_kxwn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{}, Number<WoPerThread * NPerThread>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr index_t in_block_element_size = in_chwn_block_desc.GetElementSpace();
|
||||
constexpr index_t wei_block_element_size = wei_cyxk_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_element_size];
|
||||
__shared__ Float p_wei_block[wei_block_element_size];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_global_block_begin =
|
||||
p_wei_global + wei_ek_global_desc.GetOffsetFromMultiIndex(0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.Run(p_in_global,
|
||||
c_block_data_begin,
|
||||
ho_block_data_begin,
|
||||
wo_block_data_begin,
|
||||
n_block_data_begin,
|
||||
p_in_block,
|
||||
h_block_pad_low,
|
||||
w_block_pad_low,
|
||||
h_block_pad_up,
|
||||
w_block_pad_up);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_batch_gemm.Run(
|
||||
p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_in_block + in_chwn_block_desc.GetOffsetFromMultiIndex(0, y, x, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto matrix_c_index =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t ho_thread_data_begin = matrix_c_index.batch;
|
||||
const index_t k_thread_data_begin = matrix_c_index.row;
|
||||
const index_t wo_thread_data_begin = matrix_c_index.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = matrix_c_index.col - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
#if 0
|
||||
printf("block %u %u, %u %u %u %u, %u %u %u %u, %f \n",
|
||||
get_block_1d_id(), get_thread_local_1d_id(),
|
||||
ho_block_data_begin, k_block_data_begin, wo_block_data_begin, n_block_data_begin,
|
||||
ho_thread_data_begin, k_thread_data_begin, wo_thread_data_begin, n_thread_data_begin,
|
||||
p_out_thread[0]);
|
||||
#endif
|
||||
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
|
||||
constexpr auto reorder_khwn_from_hkwn = Sequence<1, 0, 2, 3>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_khwn_global_desc,
|
||||
p_out_global +
|
||||
out_khwn_global_desc.GetOffsetFromMultiIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_khwn_from_hkwn);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,67 @@
#ifndef CK_CONSTANT_MATRIX_DESCRIPTOR_HPP
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP

#include "common_header.hpp"

namespace ck {

template <index_t NRow_, index_t NCol_, index_t RowStride_>
struct ConstantMatrixDescriptor
{
    __host__ __device__ constexpr ConstantMatrixDescriptor()
    {
        static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!");
    }

    __host__ __device__ static constexpr index_t NRow() { return NRow_; }

    __host__ __device__ static constexpr index_t NCol() { return NCol_; }

    __host__ __device__ static constexpr index_t RowStride() { return RowStride_; }

    __host__ __device__ static constexpr auto GetLengths() { return Sequence<NRow_, NCol_>{}; }

    __host__ __device__ static constexpr index_t GetElementSize() { return NRow_ * NCol_; }

    __host__ __device__ static constexpr index_t GetElementSpace() { return NRow_ * RowStride_; }

    __host__ __device__ static index_t GetOffsetFromMultiIndex(index_t irow, index_t icol)
    {
        return irow * RowStride_ + icol;
    }

    template <index_t SubNRow, index_t SubNCol>
    __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
                                                                      Number<SubNCol>)
    {
        return ConstantMatrixDescriptor<SubNRow, SubNCol, RowStride_>{};
    }
};

template <index_t NRow, index_t NCol>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>)
{
    return ConstantMatrixDescriptor<NRow, NCol, NCol>{};
}

template <index_t NRow, index_t NCol, index_t RowStride>
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(Number<NRow>, Number<NCol>, Number<RowStride>)
{
    return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
}

template <class TDesc>
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
{
    const auto desc = TDesc{};

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    printf("%s NRow %u NCol %u RowStride %u\n", s, desc.NRow(), desc.NCol(), desc.RowStride());
}

} // namespace ck

#endif
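// Illustrative usage sketch (the sizes below are made up for illustration only).
// A 16x8 row-major matrix padded to a row stride of 10 occupies 16 * 10 elements, while
// holding 16 * 8 useful values:
#if 0
    constexpr auto mtx_desc =
        make_ConstantMatrixDescriptor(Number<16>{}, Number<8>{}, Number<10>{});

    static_assert(mtx_desc.GetElementSize() == 128, "");  // NRow * NCol
    static_assert(mtx_desc.GetElementSpace() == 160, ""); // NRow * RowStride

    // element (row 3, col 5) lives at offset 3 * 10 + 5 = 35
    const index_t offset = mtx_desc.GetOffsetFromMultiIndex(3, 5);
#endif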
@@ -0,0 +1,193 @@
|
||||
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP
|
||||
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// OriginalTensorDesc : ConstantTensorDescriptor<...>
|
||||
// it's the tensor whose dimensions are to be merged
|
||||
// OriginalDimMergeSeqs : Sequence<...>...
|
||||
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
|
||||
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
|
||||
struct ConstantMergedTensorDescriptor
|
||||
{
|
||||
using Type = ConstantMergedTensorDescriptor;
|
||||
|
||||
static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};
|
||||
|
||||
static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
|
||||
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();
|
||||
|
||||
__host__ __device__ constexpr ConstantMergedTensorDescriptor()
|
||||
{
|
||||
static_assert(nDim <= nOriginalDim, "wrong!");
|
||||
|
||||
// TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most
|
||||
// OriginalTensorDesc::nDim number of dimensions
|
||||
|
||||
// TODO: check OriginalDimMergeSeqs contains all original dimensions
|
||||
|
||||
// TODO: check there is no duplication in OriginalDimMergeSeqs
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor()
|
||||
{
|
||||
return OriginalTensorDesc{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return std::get<IDim>(mOriginalDimMergeSeqs);
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return (std::get<IDim>(mOriginalDimMergeSeqs).GetSize() > 1);
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr index_t GetLength(Number<IDim>)
|
||||
{
|
||||
constexpr auto original_dims_partial = std::get<IDim>(mOriginalDimMergeSeqs);
|
||||
|
||||
return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr index_t GetStride(Number<IDim>)
|
||||
{
|
||||
static_assert(!ContainMultipleOriginalDimensions(Number<IDim>{}),
|
||||
"wrong! stride of a merged dimension is undefined");
|
||||
|
||||
constexpr auto idim_original = std::get<IDim>(mOriginalDimMergeSeqs).Front();
|
||||
|
||||
return OriginalTensorDesc::GetStride(Number<idim_original>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths()
|
||||
{
|
||||
return Sequence<OriginalTensorDesc::Extract(OriginalDimMergeSeqs{}).GetElementSize()...>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
return OriginalTensorDesc::GetElementSize();
|
||||
}
|
||||
|
||||
template <class OriginalDimsPartial>
|
||||
struct lambda_1_GetOriginalMultiIndexFromMultiIndex
|
||||
{
|
||||
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial;
|
||||
Array<index_t, nOriginalDim>& original_multi_id;
|
||||
|
||||
__host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex(
|
||||
const Array<index_t, OriginalDimsPartial::GetSize()>& original_multi_id_partial_,
|
||||
Array<index_t, nOriginalDim>& original_multi_id_)
|
||||
: original_multi_id_partial(original_multi_id_partial_),
|
||||
original_multi_id(original_multi_id_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void operator()(Number<I>) const
|
||||
{
|
||||
constexpr index_t idim_original = OriginalDimsPartial::Get(Number<I>{});
|
||||
|
||||
index_t itmp = original_multi_id_partial[I];
|
||||
|
||||
original_multi_id.Set(Number<idim_original>{}, itmp);
|
||||
}
|
||||
};
|
||||
|
||||
struct lambda_0_GetOriginalMultiIndexFromMultiIndex
|
||||
{
|
||||
const Array<index_t, nDim>& multi_id;
|
||||
Array<index_t, nOriginalDim>& original_multi_id;
|
||||
|
||||
__host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex(
|
||||
const Array<index_t, nDim>& multi_id_, Array<index_t, nOriginalDim>& original_multi_id_)
|
||||
: multi_id(multi_id_), original_multi_id(original_multi_id_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr void operator()(Number<IDim>) const
|
||||
{
|
||||
constexpr auto original_dims_partial = std::get<IDim>(Type::mOriginalDimMergeSeqs);
|
||||
|
||||
// get partial original-multi-id corresponding to this merged dimension
|
||||
const auto original_multi_id_partial =
|
||||
OriginalTensorDesc::Extract(original_dims_partial)
|
||||
.GetMultiIndexFrom1dIndex(multi_id[IDim]);
|
||||
|
||||
static_for<0, original_dims_partial.GetSize(), 1>{}(
|
||||
lambda_1_GetOriginalMultiIndexFromMultiIndex<decltype(original_dims_partial)>(
|
||||
original_multi_id_partial, original_multi_id));
|
||||
}
|
||||
};
|
||||
|
||||
// return type is Array<...>
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
Array<index_t, nOriginalDim> original_multi_id;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id));
|
||||
|
||||
return original_multi_id;
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
|
||||
{
|
||||
constexpr auto multi_id = sequence2array(Sequence<Is...>{});
|
||||
|
||||
constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
|
||||
|
||||
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id);
|
||||
|
||||
return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id);
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, nDim>{is...});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths());
|
||||
|
||||
return packed_desc.GetMultiIndexFrom1dIndex(id);
|
||||
}
|
||||
};
|
||||
|
||||
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
|
||||
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
|
||||
OriginalDimMergeSeqs...)
|
||||
{
|
||||
return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
|
||||
}
|
||||
|
||||
template <class TDesc>
|
||||
__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc)
|
||||
{
|
||||
print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor());
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
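// Illustrative usage sketch (the NCHW lengths below are made up for illustration only).
// Merging the H and W dimensions of a packed 4-d descriptor yields a 3-d view whose last
// dimension has length H * W; indexing through the merged view maps back to the original
// 4-d offsets:
#if 0
    constexpr auto desc_nchw = make_ConstantTensorDescriptor_packed(Sequence<2, 3, 4, 5>{});

    constexpr auto desc_nc_hw = make_ConstantMergedTensorDescriptor(
        desc_nchw, Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{});

    static_assert(desc_nc_hw.GetNumOfDimension() == 3, "");
    static_assert(desc_nc_hw.GetLength(Number<2>{}) == 20, ""); // H * W = 4 * 5

    // multi-index (n, c, hw) = (1, 2, 13) maps to original (1, 2, 2, 3),
    // i.e. offset 1*60 + 2*20 + 2*5 + 3 = 113
    const index_t offset = desc_nc_hw.GetOffsetFromMultiIndex(1, 2, 13);
#endif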
|
||||
@@ -0,0 +1,519 @@
|
||||
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_HPP
|
||||
#define CK_CONSTANT_TENSOR_DESCRIPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(
|
||||
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
|
||||
.PushBack(Number<1>{});
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
constexpr index_t L_back_align =
|
||||
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);
|
||||
|
||||
return calculate_tensor_strides_packed(
|
||||
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor;
|
||||
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
{
|
||||
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; }
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number<IDim>)
|
||||
{
|
||||
return Sequence<IDim>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfDimension() { return nDim; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLengths() { return Lengths{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetStrides() { return Strides{}; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr index_t GetLength(Number<I>)
|
||||
{
|
||||
return Lengths::Get(Number<I>{});
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr index_t GetStride(Number<I>)
|
||||
{
|
||||
return Strides::Get(Number<I>{});
|
||||
}
|
||||
|
||||
struct lambda_AreDimensionsContinuous
|
||||
{
|
||||
bool& is_continuous;
|
||||
|
||||
__host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_)
|
||||
: is_continuous(is_continuous_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim_>
|
||||
__host__ __device__ constexpr void operator()(Number<IDim_>) const
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
constexpr auto IDim_p1 = Number<IDim_ + 1>{};
|
||||
|
||||
is_continuous =
|
||||
is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) &&
|
||||
GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1));
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr bool AreDimensionsContinuous()
|
||||
{
|
||||
bool is_continuous = true;
|
||||
|
||||
static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous));
|
||||
|
||||
return is_continuous;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsPackedTensor()
|
||||
{
|
||||
return AreDimensionsContinuous() && GetStride(Number<nDim - 1>{}) == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetElementSize()
|
||||
{
|
||||
return accumulate_on_sequence(Lengths{}, math::multiplies<index_t>{}, Number<1>{});
|
||||
}
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ static constexpr index_t GetElementSpace(Align align = Align{})
|
||||
{
|
||||
// FIXME: this is wrong: the alignment should be applied to the last memory rank, not to the
// last tensor dimension
|
||||
constexpr index_t element_space_unaligned = accumulate_on_sequence(
|
||||
(GetLengths() - Number<1>{}) * GetStrides(), math::plus<index_t>{}, Number<1>{});
|
||||
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
// emulate constexpr lambda
|
||||
template <index_t NSize>
|
||||
struct lambda_GetOffsetFromMultiIndex
|
||||
{
|
||||
Array<index_t, NSize>& multi_id;
|
||||
index_t& offset;
|
||||
|
||||
__host__
|
||||
__device__ constexpr lambda_GetOffsetFromMultiIndex(Array<index_t, NSize>& multi_id_,
|
||||
index_t& offset_)
|
||||
: multi_id(multi_id_), offset(offset_)
|
||||
{
|
||||
}
|
||||
|
||||
template <class X>
|
||||
__host__ __device__ constexpr void operator()(X IDim) const
|
||||
{
|
||||
offset += multi_id[IDim] * Type::GetStride(IDim);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static constexpr index_t
|
||||
GetOffsetFromMultiIndex(Array<index_t, NSize> multi_id)
|
||||
{
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
index_t offset = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex<NSize>(multi_id, offset));
|
||||
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is)
|
||||
{
|
||||
return GetOffsetFromMultiIndex(Array<index_t, sizeof...(Is)>{is...});
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence<Is...>)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
constexpr auto multi_id = Sequence<Is...>{};
|
||||
|
||||
return accumulate_on_sequence(multi_id * GetStrides(), math::plus<index_t>{}, Number<0>{});
|
||||
}
|
||||
|
||||
// emulate constexpr lambda
|
||||
template <class PackedStrides>
|
||||
struct lambda_GetMultiIndexFrom1dIndex
|
||||
{
|
||||
index_t& id;
|
||||
Array<index_t, nDim>& multi_id;
|
||||
|
||||
__host__
|
||||
__device__ constexpr lambda_GetMultiIndexFrom1dIndex(index_t& id_,
|
||||
Array<index_t, nDim>& multi_id_)
|
||||
: id(id_), multi_id(multi_id_)
|
||||
{
|
||||
}
|
||||
|
||||
template <class IDim_>
|
||||
__host__ __device__ constexpr void operator()(IDim_) const
|
||||
{
|
||||
constexpr auto IDim = IDim_{};
|
||||
constexpr index_t stride = PackedStrides::Get(IDim);
|
||||
multi_id.Set(IDim, id / stride);
|
||||
id -= multi_id[IDim] * stride;
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr Array<index_t, nDim> GetMultiIndexFrom1dIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
using PackedStrides = decltype(calculate_tensor_strides_packed(GetLengths()));
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
|
||||
|
||||
multi_id.Set(Number<nDim - 1>{}, id / PackedStrides::Get(Number<nDim - 1>{}));
|
||||
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetOriginalMultiIndexFromMultiIndex(Array<index_t, nDim> multi_id)
|
||||
{
|
||||
return multi_id;
|
||||
}
|
||||
|
||||
// This function doesn't do a carry check on the highest dimension for positive stepping (or a
// borrow check on the lowest dimension for negative stepping), for performance reasons. It is
// the user's responsibility to make sure the resulting "new_multi_id" is not out-of-bound on
// the highest dimension for positive stepping (or on the lowest dimension for negative
// stepping).
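// Illustrative example: for lengths {2, 3, 4} and a positive step of 2, the step decomposes
// as GetMultiIndexFrom1dIndex(2) = {0, 0, 2}; starting from {0, 2, 3} the raw sum is
// {0, 2, 5}, the lowest dimension wraps (5 -> 1, carry), the middle dimension then wraps
// (3 -> 0, carry), and the result is {1, 0, 1}, i.e. 1d index 11 + 2 = 13.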
|
||||
template <bool PositiveDirection>
|
||||
__host__ __device__ static Array<index_t, nDim>
|
||||
UpdateMultiIndexGivenStepSizeOf1dIndex(Array<index_t, nDim> old_multi_id,
|
||||
index_t step_size_of_1d_index,
|
||||
integral_constant<bool, PositiveDirection>)
|
||||
{
|
||||
Array<index_t, nDim> new_multi_id;
|
||||
|
||||
const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index);
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto) {
|
||||
new_multi_id = old_multi_id + step_sizes;
|
||||
|
||||
bool carry = false;
|
||||
|
||||
// do carry check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
|
||||
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
|
||||
constexpr auto IDim = Number<idim>{};
|
||||
|
||||
if(carry)
|
||||
{
|
||||
++new_multi_id(idim);
|
||||
}
|
||||
|
||||
carry = false;
|
||||
|
||||
if(new_multi_id[idim] >= GetLength(IDim))
|
||||
{
|
||||
new_multi_id(idim) -= GetLength(IDim);
|
||||
carry = true;
|
||||
}
|
||||
});
|
||||
}).Else([&](auto) {
|
||||
// shift up multi-id to avoid unsigned integer underflow during intermediate
|
||||
// calculations. After the shift, should have new_multi_id[...] >= 1
|
||||
new_multi_id = old_multi_id + (GetLengths() - step_sizes);
|
||||
|
||||
bool borrow = false;
|
||||
|
||||
// do borrow check in reversed order, starting from lowest dimension
|
||||
// don't check the highest dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDimReverse) {
|
||||
constexpr index_t idim = nDim - 1 - IDimReverse.Get();
|
||||
constexpr auto IDim = Number<idim>{};
|
||||
|
||||
if(borrow)
|
||||
{
|
||||
--new_multi_id(idim);
|
||||
}
|
||||
|
||||
borrow = false;
|
||||
|
||||
if(new_multi_id[idim] < GetLength(IDim))
|
||||
{
|
||||
new_multi_id(idim) += GetLength(IDim);
|
||||
borrow = true;
|
||||
}
|
||||
});
|
||||
|
||||
// shift back down multi-id
|
||||
// here, should have new_multi_id[...] >= GetLengths()
|
||||
new_multi_id = new_multi_id - GetLengths();
|
||||
});
|
||||
|
||||
return new_multi_id;
|
||||
}
|
||||
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto Extract(Number<IDims>... extract_dims)
|
||||
{
|
||||
static_assert(sizeof...(IDims) <= GetNumOfDimension(),
|
||||
"wrong! too many number of dimensions to be extracted");
|
||||
|
||||
using extract_lengths = decltype(Lengths::Extract(extract_dims...));
|
||||
using extract_strides = decltype(Strides::Extract(extract_dims...));
|
||||
|
||||
return ConstantTensorDescriptor<extract_lengths, extract_strides>{};
|
||||
}
|
||||
|
||||
template <index_t... IDims>
|
||||
__host__ __device__ static constexpr auto Extract(Sequence<IDims...>)
|
||||
{
|
||||
return Extract(Number<IDims>{}...);
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
using leaf_tensor = ConstantTensorDescriptor<Ts...>;
|
||||
|
||||
return ConstantTensorDescriptor<decltype(GetLengths().Append(leaf_tensor::GetLengths())),
|
||||
decltype(GetStrides().Append(leaf_tensor::GetStrides()))>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLen>
|
||||
__host__ __device__ static constexpr auto Slice(Number<IDim>, Number<SliceLen>)
|
||||
{
|
||||
using slice_lengths = decltype(Lengths{}.Modify(Number<IDim>{}, Number<SliceLen>{}));
|
||||
|
||||
return ConstantTensorDescriptor<slice_lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... FoldIntervals>
|
||||
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
|
||||
{
|
||||
constexpr auto fold_intervals = Sequence<FoldIntervals...>{};
|
||||
|
||||
constexpr index_t fold_intervals_product =
|
||||
accumulate_on_sequence(fold_intervals, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr auto unfold_length = GetLength(Number<IDim>{});
|
||||
constexpr auto unfold_stride = GetStride(Number<IDim>{});
|
||||
|
||||
// the length of the dimension to be folded needs to be divisible by fold_intervals_product,
// otherwise folding is invalid
|
||||
static_assert(unfold_length % fold_intervals_product == 0,
|
||||
"wrong! length on the dimension to be folded cannot be evenly divided!");
|
||||
|
||||
// folded lengths
|
||||
constexpr auto fold_lengths =
|
||||
Sequence<unfold_length / fold_intervals_product>{}.Append(fold_intervals);
|
||||
|
||||
// folded strides
|
||||
constexpr auto fold_strides =
|
||||
Number<unfold_stride>{} *
|
||||
reverse_inclusive_scan_sequence(
|
||||
fold_intervals.PushBack(Number<1>{}), math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::SeqType{};
|
||||
constexpr auto right =
|
||||
typename arithmetic_sequence_gen<IDim + 1, GetNumOfDimension(), 1>::SeqType{};
|
||||
|
||||
constexpr auto new_lengths =
|
||||
GetLengths().Extract(left).Append(fold_lengths).Append(GetLengths().Extract(right));
|
||||
constexpr auto new_strides =
|
||||
GetStrides().Extract(left).Append(fold_strides).Append(GetStrides().Extract(right));
|
||||
|
||||
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
|
||||
}
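// Illustrative example for Fold (made-up sizes): for lengths {8, 5} with packed strides
// {5, 1}, Fold(Number<0>{}, Number<2>{}) splits dimension 0 into 8/2 = 4 outer and 2 inner
// sub-dimensions, giving lengths {4, 2, 5} and strides {10, 5, 1}.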
|
||||
|
||||
// this function unfolds dimensions [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension
|
||||
template <index_t FirstUnfoldDim, index_t LastUnfoldDim>
|
||||
__host__ __device__ static constexpr auto Unfold(Number<FirstUnfoldDim>, Number<LastUnfoldDim>)
|
||||
{
|
||||
static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim &&
|
||||
FirstUnfoldDim <= LastUnfoldDim,
|
||||
"wrong! should have FirstUnfoldDim <= LastUnfoldDim!");
|
||||
|
||||
// left and right
|
||||
constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::SeqType{};
|
||||
constexpr auto middle =
|
||||
typename arithmetic_sequence_gen<FirstUnfoldDim, LastUnfoldDim + 1, 1>::SeqType{};
|
||||
constexpr auto right =
|
||||
typename arithmetic_sequence_gen<LastUnfoldDim + 1, GetNumOfDimension(), 1>::SeqType{};
|
||||
|
||||
// dimensions to be unfolded need to be continuous
|
||||
static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable");
|
||||
|
||||
// unfolded length, stride
|
||||
constexpr index_t unfold_length = accumulate_on_sequence(
|
||||
GetLengths().Extract(middle), math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
constexpr index_t unfold_stride = GetStride(Number<LastUnfoldDim>{});
|
||||
|
||||
// new lengths, strides
|
||||
constexpr auto new_lengths = GetLengths()
|
||||
.Extract(left)
|
||||
.PushBack(Number<unfold_length>{})
|
||||
.Append(GetLengths().Extract(right));
|
||||
|
||||
constexpr auto new_strides = GetStrides()
|
||||
.Extract(left)
|
||||
.PushBack(Number<unfold_stride>{})
|
||||
.Append(GetStrides().Extract(right));
|
||||
|
||||
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
|
||||
}
|
||||
|
||||
template <class MapNew2Old>
|
||||
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
|
||||
{
|
||||
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
|
||||
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
|
||||
}
|
||||
|
||||
#if 0 // require sequence_sort, which is not implemented yet
|
||||
template <class MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
|
||||
{
|
||||
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
|
||||
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_packed(Lengths{}));
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
|
||||
{
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_aligned(Lengths{}, Number<Align>{}));
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ void
|
||||
print_ConstantTensorDescriptor(const char* s,
|
||||
ConstantTensorDescriptor<Sequence<Lengths...>, Sequence<Strides...>>)
|
||||
{
|
||||
constexpr index_t ndim = sizeof...(Lengths);
|
||||
|
||||
static_assert(ndim > 0 && ndim <= 10, "wrong!");
|
||||
|
||||
static_if<ndim == 1>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 2>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 3>{}([&](auto) {
|
||||
printf(
|
||||
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 4>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 5>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 6>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 7>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
|
||||
"%u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<ndim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
ndim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
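// Illustrative sketch (the lengths and alignment below are made up for illustration only).
// Packed strides are a reverse product scan of the lengths; aligned strides first pad the
// last dimension up to the requested alignment:
#if 0
    // lengths {2, 3, 3} -> packed strides {9, 3, 1}
    constexpr auto desc_packed = make_ConstantTensorDescriptor_packed(Sequence<2, 3, 3>{});

    // lengths {2, 3, 3} aligned to 4 -> strides {12, 4, 1}
    constexpr auto desc_aligned =
        make_ConstantTensorDescriptor_aligned(Sequence<2, 3, 3>{}, Number<4>{});
#endif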
|
||||
@@ -0,0 +1,806 @@
|
||||
#ifndef CK_BLOCKWISE_2D_TENSOR_OP_HPP
|
||||
#define CK_BLOCKWISE_2D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc, class F>
|
||||
__device__ void
|
||||
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
|
||||
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < desc.GetElementSize())
|
||||
{
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0, i1]
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
index_t did[2];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
|
||||
|
||||
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
index_t did[2];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const index_t aindex = src_desc.GetOffsetFromMultiIndex(did[0], did[1]);
|
||||
|
||||
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc>
|
||||
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
|
||||
{
|
||||
auto f_set_zero = [](Float& v) { v = Float(0); };
|
||||
|
||||
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise2dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
__device__ constexpr Blockwise2dTensorCopy1()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
|
||||
"wrong! only support stride1 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I0) % DataPerRead == 0,
|
||||
"src and dst stride2 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
// we allow out-of-bound read from src in D1 dimension,
|
||||
// but we need to make sure dst stride0 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);
|
||||
|
||||
static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t read_per_d1 = math::integer_divide_ceil(L1, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
auto f_copy = [&](index_t is) {
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
f_copy(is);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
f_copy(is);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// data needs to be aligned to float4 and float2
// stride1 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
index_t ThreadPerDim0,
|
||||
index_t ThreadPerDim1>
|
||||
struct Blockwise2dTensorCopy2
|
||||
{
|
||||
index_t mThreadId0;
|
||||
index_t mThreadId1;
|
||||
|
||||
__device__ Blockwise2dTensorCopy2()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
|
||||
"wrong! stride is not 1!\n");
|
||||
|
||||
mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
|
||||
mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
using Float4 = float4;
|
||||
using Float2 = float2;
|
||||
|
||||
if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
|
||||
return;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
// check alignment
|
||||
constexpr bool align_v4 =
|
||||
src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;
|
||||
|
||||
constexpr bool align_v2 =
|
||||
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
|
||||
|
||||
constexpr index_t L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr index_t L1 = SrcOpLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t Dim0Loop = L0 / ThreadPerDim0;
|
||||
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
|
||||
|
||||
constexpr index_t Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
|
||||
|
||||
constexpr index_t Dim1V2Loop =
|
||||
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
|
||||
|
||||
constexpr index_t Dim1V1Loop =
|
||||
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
|
||||
ThreadPerDim1;
|
||||
|
||||
constexpr bool d1_has_tail =
|
||||
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
|
||||
|
||||
for(index_t d0loop = 0; d0loop < Dim0Loop; ++d0loop)
|
||||
{
|
||||
index_t did0 = d0loop * ThreadPerDim0 + mThreadId0;
|
||||
|
||||
// v4
|
||||
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
|
||||
{
|
||||
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v2
|
||||
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
|
||||
{
|
||||
index_t did1 =
|
||||
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v1
|
||||
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
d1v1loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
|
||||
// dim-1 tail
|
||||
if(d1_has_tail)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
if(did1 < L1)
|
||||
{
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dim-0 tail
|
||||
if(d0_has_tail)
|
||||
{
|
||||
index_t did0 = Dim0Loop * ThreadPerDim0 + mThreadId0;
|
||||
|
||||
if(did0 < L0)
|
||||
{
|
||||
|
||||
// v4
|
||||
for(index_t d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
|
||||
{
|
||||
index_t did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v2
|
||||
for(index_t d1v2loop = 0; d1v2loop < Dim1V2Loop; ++d1v2loop)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
|
||||
2 * mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v1
|
||||
for(index_t d1v1loop = 0; d1v1loop < Dim1V1Loop; ++d1v1loop)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
d1v1loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
|
||||
// tail
|
||||
if(d1_has_tail)
|
||||
{
|
||||
index_t did1 = Dim1V4Loop * 4 * ThreadPerDim1 + Dim1V2Loop * 2 * ThreadPerDim1 +
|
||||
Dim1V1Loop * ThreadPerDim1 + mThreadId1;
|
||||
|
||||
if(did1 < L1)
|
||||
{
|
||||
const index_t sindex = src_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// starting point needs to be aligned to float4, float2 or float
// stride1 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise2dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise2dTensorCopy3(Array<index_t, 2> src_block_data_multi_id_begin,
|
||||
Array<index_t, 2> dst_block_data_multi_id_begin)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
|
||||
"wrong! only support stride1 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I0) % DataPerRead == 0,
|
||||
"src and dst stride should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
// we allow out-of-bound read from src in D1 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
|
||||
static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover one line\n");
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
|
||||
const index_t thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
src_block_data_multi_id_begin +
|
||||
Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
dst_block_data_multi_id_begin +
|
||||
Array<index_t, 2>{thread_id_d0, thread_id_d1 * DataPerRead});
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
*(reinterpret_cast<vector_t*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + mSrcMyThreadOffset +
|
||||
iloop * src_loop_stride));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ constexpr index_t GetRegisterClipboardSize() const
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
return DataPerRead * (L0 + thread_per_d0 - 1) / thread_per_d0;
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
|
||||
*(reinterpret_cast<const vector_t*>(
|
||||
&p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
__device__ void RunLoadRegisterClipboard_asm(const Float* __restrict__ p_src,
|
||||
Float* p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
#if 0
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[iloop * DataPerRead])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[mSrcMyThreadOffset +
|
||||
iloop * src_loop_stride]));
|
||||
#else
|
||||
static_assert(is_same<float, Float>::value && DataPerRead == 4,
|
||||
"global_load is only for float4");
|
||||
|
||||
global_load(reinterpret_cast<vector_t&>(p_clipboard[iloop * DataPerRead]),
|
||||
reinterpret_cast<const vector_t*>(
|
||||
&p_src[mSrcMyThreadOffset + iloop * src_loop_stride]));
|
||||
#endif
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard_asm(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr index_t thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr index_t src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr index_t dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
auto f_copy = [&](index_t iloop) {
|
||||
#if 0
|
||||
            *(reinterpret_cast<vector_t*>(&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride])) =
                *(reinterpret_cast<const vector_t*>(&p_clipboard[iloop * DataPerRead]));
|
||||
#else
|
||||
static_assert(is_same<float, Float>::value && DataPerRead == 4,
|
||||
"ds_write_b128 is only for float4");
|
||||
|
||||
ds_write_b128(reinterpret_cast<const vector_t&>(p_clipboard[iloop * DataPerRead]),
|
||||
&p_dst[mDstMyThreadOffset + iloop * dst_loop_stride]);
|
||||
#endif
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
f_copy(iloop);
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr index_t tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
f_copy(nloop_d0);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,378 @@
|
||||
#ifndef CK_BLOCKWISE_3D_TENSOR_OP_HPP
|
||||
#define CK_BLOCKWISE_3D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise3dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
__device__ constexpr Blockwise3dTensorCopy1()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
|
||||
"wrong! only support stride2 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I1) % DataPerRead == 0,
|
||||
"src and dst stride1 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
// we allow out-of-bound read from src in D3 dimension,
|
||||
// but we need to make sure dst stride2 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
|
||||
|
||||
static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t read_per_d2 = math::integer_divide_ceil(L2, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
auto f_copy = [&](index_t is) {
|
||||
index_t did[3];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
f_copy(is);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
f_copy(is);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
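
// Usage sketch for Blockwise3dTensorCopy1 (illustrative only; the block size,
// tensor shapes and pointer names below are hypothetical, not taken from any
// kernel in this repository). All threads of the block cooperate in a single
// Run() call, here copying a 2x8x64 tile with float4 reads:
//
//   constexpr auto src_desc = make_ConstantTensorDescriptor_packed(Sequence<2, 8, 64>{});
//   constexpr auto dst_desc = make_ConstantTensorDescriptor_packed(Sequence<2, 8, 64>{});
//
//   using BlockCopy = Blockwise3dTensorCopy1<256,                // BlockSize
//                                            float,
//                                            decltype(src_desc),
//                                            decltype(dst_desc),
//                                            Sequence<2, 8, 64>, // CopyLengths
//                                            4>;                 // DataPerRead (float4)
//
//   BlockCopy{}.Run(p_src_tile, p_dst_tile);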
|
||||
|
||||
// starting point needs to be aligned to float4, float2 or float
// stride2 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
class ThreadPerDims,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise3dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise3dTensorCopy3()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
|
||||
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I1) % DataPerRead == 0,
|
||||
"wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
// we allow out-of-bound read from src in D2 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
|
||||
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
|
||||
"wrong! L0, L1, L2 should be divided evenly!\n");
|
||||
|
||||
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
|
||||
const auto thread_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) = *(
|
||||
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
|
||||
}
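
    // Worked example for GetRegisterClipboardSize() above (hypothetical parameters):
    // with CopyLengths = Sequence<2, 8, 64>, ThreadPerDims = Sequence<2, 8, 16> and
    // DataPerRead = 4,
    //   nloop_d0 = 2 / 2 = 1, nloop_d1 = 8 / 8 = 1, nloop_d2 = ceil(64 / (16 * 4)) = 1,
    // so each thread reserves 4 * 1 * 1 * 1 = 4 floats of register clipboard.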
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) = *(
|
||||
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = math::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.GetOffsetFromMultiIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
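
// Register-clipboard usage sketch for Blockwise3dTensorCopy3 (illustrative only;
// descriptors, sizes and pointer names are hypothetical). Splitting the copy into a
// load step and a store step lets the caller issue other work between the two:
//
//   constexpr auto src_desc = make_ConstantTensorDescriptor_packed(Sequence<2, 8, 64>{});
//   constexpr auto dst_desc = make_ConstantTensorDescriptor_packed(Sequence<2, 8, 64>{});
//
//   using BlockCopy = Blockwise3dTensorCopy3<256, float,
//                                            decltype(src_desc), decltype(dst_desc),
//                                            Sequence<2, 8, 64>,  // CopyLengths
//                                            Sequence<2, 8, 16>,  // ThreadPerDims
//                                            4>;                  // DataPerRead
//   BlockCopy blockwise_copy;
//
//   float p_clipboard[BlockCopy::GetRegisterClipboardSize()];
//
//   blockwise_copy.RunLoadRegisterClipboard(p_src, p_clipboard);  // src -> registers
//   // ... other independent work ...
//   blockwise_copy.RunStoreRegisterClipboard(p_clipboard, p_dst); // registers -> dst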
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,779 @@
|
||||
#ifndef CK_BLOCKWISE_4D_TENSOR_OP_HPP
|
||||
#define CK_BLOCKWISE_4D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc, class F>
|
||||
__device__ void
|
||||
blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto desc = make_ConstantTensorDescriptor_packed(dst_desc.GetLengths());
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
|
||||
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr index_t NLoop = desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
is -= did1 * desc.GetStride(I1);
|
||||
|
||||
const index_t did2 = is / desc.GetStride(I2);
|
||||
|
||||
is -= did2 * desc.GetStride(I2);
|
||||
|
||||
const index_t did3 = is / desc.GetStride(I3);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < desc.GetElementSize())
|
||||
{
|
||||
const index_t did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const index_t did1 = is / desc.GetStride(I1);
|
||||
|
||||
is -= did1 * desc.GetStride(I1);
|
||||
|
||||
const index_t did2 = is / desc.GetStride(I2);
|
||||
|
||||
is -= did2 * desc.GetStride(I2);
|
||||
|
||||
const index_t did3 = is / desc.GetStride(I3);
|
||||
|
||||
const index_t dindex = dst_desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
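// Example (hypothetical map): with MapDst2Src = Sequence<1, 2, 3, 0>, element
// p_src[i0][i1][i2][i3] is written to p_dst[i1][i2][i3][i0], i.e. an NCHW source
// tile ends up stored in CHWN order in the destination.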
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
constexpr index_t IR2 = MapDst2Src{}.Get(I2);
|
||||
constexpr index_t IR3 = MapDst2Src{}.Get(I3);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
f(p_src[src_index], p_dst[dst_index]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
f(p_src[src_index], p_dst[dst_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t BlockSize, class Float, class DstDesc>
|
||||
__device__ void blockwise_4d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
|
||||
{
|
||||
auto f_set_zero = [](Float& v) { v = Float(0); };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise4dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
__device__ constexpr Blockwise4dTensorCopy1()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
|
||||
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I2) % DataPerRead == 0,
|
||||
"src and dst stride2 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
// we allow out-of-bound read from src in D3 dimension,
|
||||
// but we need to make sure dst stride2 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
|
||||
|
||||
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t read_per_d3 = math::integer_divide_ceil(L3, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<L0, L1, L2, read_per_d3>{});
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
auto f_copy = [&](index_t is) {
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t src_index =
|
||||
src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
|
||||
const index_t dst_index =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3] * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const vector_t*>(p_src + src_index));
|
||||
};
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
f_copy(is);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
f_copy(is);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class DstOpLengths,
|
||||
class GlobalLowerPads>
|
||||
struct BlockwiseChwnTensorCopyPadded
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_src,
|
||||
index_t c_block_data_begin,
|
||||
index_t ho_block_data_begin,
|
||||
index_t wo_block_data_begin,
|
||||
index_t n_block_data_begin,
|
||||
Float* __restrict__ p_dst,
|
||||
index_t h_block_pad_low,
|
||||
index_t w_block_pad_low,
|
||||
index_t h_block_pad_up,
|
||||
index_t w_block_pad_up) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(DstOpLengths{});
|
||||
|
||||
constexpr auto h_global_pad_low = GlobalLowerPads{}.Get(I0);
|
||||
constexpr auto w_global_pad_low = GlobalLowerPads{}.Get(I1);
|
||||
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
const Float* p_src_tmp = p_src +
|
||||
src_desc.GetOffsetFromMultiIndex(
|
||||
c_block_data_begin,
|
||||
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
|
||||
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
|
||||
n_block_data_begin);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "src_desc: ");
|
||||
print_ConstantTensorDescriptor(dst_desc, "dst_desc: ");
|
||||
print_ConstantTensorDescriptor(ref_desc, "ref_desc: ");
|
||||
|
||||
printf("%u %u, \t"
|
||||
"h_global_pad_low %u w_global_pad_low %u \t"
|
||||
"h_block_pad_low %u w_block_pad_low %u h_block_pad_up %u w_block_pad_up %u \t"
|
||||
"\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
h_global_pad_low,
|
||||
w_global_pad_low,
|
||||
h_block_pad_low,
|
||||
w_block_pad_low,
|
||||
h_block_pad_up,
|
||||
w_block_pad_up);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(index_t iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
|
||||
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t bindex = dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
p_dst[bindex] =
|
||||
(did[1] < h_block_pad_low || did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
|
||||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
|
||||
? Float(0)
|
||||
: p_src_tmp[src_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3])];
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
index_t did[4];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
is -= did[1] * ref_desc.GetStride(I1);
|
||||
|
||||
did[2] = is / ref_desc.GetStride(I2);
|
||||
|
||||
is -= did[2] * ref_desc.GetStride(I2);
|
||||
|
||||
did[3] = is / ref_desc.GetStride(I3);
|
||||
|
||||
const index_t bindex =
|
||||
dst_desc.GetOffsetFromMultiIndex(did[0], did[1], did[2], did[3]);
|
||||
|
||||
p_dst[bindex] =
|
||||
(did[1] < h_block_pad_low ||
|
||||
did[1] + h_block_pad_up >= ref_desc.GetLength(I1) ||
|
||||
did[2] < w_block_pad_low || did[2] + w_block_pad_up >= ref_desc.GetLength(I2))
|
||||
? Float(0)
|
||||
: p_src_tmp[src_desc.GetOffsetFromMultiIndex(
|
||||
did[0], did[1], did[2], did[3])];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
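
// Example of the zero-fill predicate in BlockwiseChwnTensorCopyPadded::Run() above
// (hypothetical values): with h_block_pad_low = 1 and w_block_pad_low = 0, every
// destination element whose H index within the tile (did[1]) is 0 lies in the top
// padding row and is written as Float(0); all other elements are copied from the
// pad-shifted source pointer p_src_tmp.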
|
||||
|
||||
// starting point needs to be aligned to float4, float2 or float
// stride3 needs to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
class ThreadPerDims,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise4dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise4dTensorCopy3()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1),
|
||||
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I2) % DataPerRead == 0,
|
||||
"wrong! src and dst stride2 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
// we allow out-of-bound read from src in D3 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
|
||||
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0 && L2 % thread_per_d2 == 0,
|
||||
"wrong! L0, L1, L2 should be divided evenly!\n");
|
||||
|
||||
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3,
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
accumulate_on_sequence(ThreadPerDims{}, math::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(ThreadPerDims{});
|
||||
const auto thread_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
thread_multi_id[2],
|
||||
thread_multi_id[3] * DataPerRead);
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.GetOffsetFromMultiIndex(thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
thread_multi_id[2],
|
||||
thread_multi_id[3] * DataPerRead);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(
|
||||
&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ constexpr index_t GetRegisterClipboardSize() const
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) =
|
||||
*(reinterpret_cast<const vector_t*>(
|
||||
&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
constexpr index_t thread_per_d3 = ThreadPerDims{}.Get(I3);
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
thread_per_d0 * thread_per_d1 * thread_per_d2 * thread_per_d3;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = math::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const index_t clipboard_offset = clipboard_desc.GetOffsetFromMultiIndex(
|
||||
iloop_d0, iloop_d1, iloop_d2, iloop_d3 * DataPerRead);
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(
|
||||
iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2,
|
||||
iloop_d3 * thread_per_d3 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
struct Blockwise4dTensorCopyReorder1
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,529 @@
|
||||
#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
|
||||
#define CK_BLOCKWISE_BATCHED_GEMM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "threadwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
index_t BlockMatrixStrideA,
|
||||
index_t BlockMatrixStrideB,
|
||||
index_t ThreadMatrixStrideC,
|
||||
index_t BatchSize,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t MLevel0Cluster,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t BatchPerThread,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
{
|
||||
index_t mMyThreadOffsetA = 0;
|
||||
index_t mMyThreadOffsetB = 0;
|
||||
|
||||
struct MatrixIndex
|
||||
{
|
||||
index_t batch;
|
||||
index_t row;
|
||||
index_t col;
|
||||
};
|
||||
|
||||
__device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
|
||||
{
|
||||
static_assert(BatchSize % BatchPerThread == 0,
|
||||
"wrong! BatchSize is not dividable by BatchPerThread");
|
||||
|
||||
constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
|
||||
|
||||
constexpr index_t ThreadPerLevel1Cluster =
|
||||
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
|
||||
|
||||
static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
|
||||
"wrong! wrong blocksize\n");
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
|
||||
"wrong! K dimension not consistent\n");
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol(); // A is transposed
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
|
||||
"wrong! Cannot evenly divide thread work among repeat \n");
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
|
||||
"wrong! Cannot evenly divide work among repeat\n");
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = M / MRepeat;
|
||||
constexpr index_t NPerLevel1Cluster = N / NRepeat;
|
||||
|
||||
static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
|
||||
(NPerLevel1Cluster % NLevel1Cluster == 0),
|
||||
"wrong! Cannot evenly divide work among Level1Cluster\n");
|
||||
|
||||
constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
|
||||
|
||||
static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
|
||||
(NPerLevel0Cluster % NLevel0Cluster == 0),
|
||||
"wrong! Cannot evenly divide work among Level0Cluster\n");
|
||||
|
||||
static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
|
||||
(NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
|
||||
"wrong! thread work size is wrong\n");
|
||||
|
||||
const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
|
||||
a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
|
||||
|
||||
mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
|
||||
b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
|
||||
print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
|
||||
print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
|
||||
|
||||
printf("%u %u, %u %u %u, %u %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_mtx_index.batch,
|
||||
c_thread_mtx_index.row,
|
||||
c_thread_mtx_index.col,
|
||||
mMyThreadOffsetA,
|
||||
mMyThreadOffsetB);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
|
||||
{
|
||||
constexpr index_t ThreadPerLevel1Cluster =
|
||||
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
|
||||
index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
|
||||
index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster;
|
||||
|
||||
index_t level1_id = cluster_id / ThreadPerLevel0Cluster;
|
||||
index_t level1_m_id = level1_id / NLevel1Cluster;
|
||||
index_t level1_n_id = level1_id % NLevel1Cluster;
|
||||
|
||||
index_t level0_id = cluster_id % ThreadPerLevel0Cluster;
|
||||
index_t level0_m_id = level0_id / NLevel0Cluster;
|
||||
index_t level0_n_id = level0_id % NLevel0Cluster;
|
||||
|
||||
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
|
||||
return MatrixIndex{batch_work_id * BatchPerThread,
|
||||
level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
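
    // Worked example for GetBeginOfThreadMatrixC() above (hypothetical template
    // parameters): with MLevel0Cluster = NLevel0Cluster = MLevel1Cluster =
    // NLevel1Cluster = 4, MPerThreadSubC = NPerThreadSubC = 4 and BatchPerThread = 1,
    // ThreadPerLevel1Cluster = 256 and ThreadPerLevel0Cluster = 16. For thread_id = 77:
    //   batch_work_id = 0, cluster_id = 77,
    //   level1_id = 4  -> level1_m_id = 1, level1_n_id = 0,
    //   level0_id = 13 -> level0_m_id = 3, level0_n_id = 1,
    // so the thread's C sub-tile starts at
    //   MatrixIndex{0, 1 * 16 + 3 * 4, 0 * 16 + 1 * 4} = MatrixIndex{0, 28, 4}.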
|
||||
|
||||
// this should be optimized away because input will be known at compile time
|
||||
__device__ static MatrixIndex
|
||||
GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
|
||||
{
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
index_t m_repeat = m_in_c / MPerThreadSubC;
|
||||
index_t n_repeat = n_in_c / NPerThreadSubC;
|
||||
|
||||
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
|
||||
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
|
||||
|
||||
return MatrixIndex{batch_in_c,
|
||||
m_repeat * MPerLevel1Cluster + m_in_sub_c,
|
||||
n_repeat * NPerLevel1Cluster + n_in_sub_c};
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
// A is transposed, b is not
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// loop over k
|
||||
#pragma unroll
|
||||
for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
|
||||
{
|
||||
// loop over batch
|
||||
#pragma unroll
|
||||
for(index_t ib = 0; ib < BatchPerThread; ++ib)
|
||||
{
|
||||
// read next batch of a, b
|
||||
                if(BlockMatrixStrideA != 0 || ib == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
p_a_block +
|
||||
a_block_mtx.GetOffsetFromMultiIndex(k_begin,
|
||||
m_repeat * MPerLevel1Cluster) +
|
||||
ib * BlockMatrixStrideA + mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread +
|
||||
a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
}
|
||||
|
||||
                if(BlockMatrixStrideB != 0 || ib == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
p_b_block +
|
||||
b_block_mtx.GetOffsetFromMultiIndex(k_begin,
|
||||
n_repeat * NPerLevel1Cluster) +
|
||||
ib * BlockMatrixStrideB + mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread +
|
||||
b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
|
||||
p_a_thread[0],
|
||||
p_a_thread[1],
|
||||
p_a_thread[2],
|
||||
p_a_thread[3],
|
||||
p_a_thread[4],
|
||||
p_a_thread[5],
|
||||
p_a_thread[6],
|
||||
p_a_thread[7],
|
||||
p_b_thread[0],
|
||||
p_b_thread[1],
|
||||
p_b_thread[2],
|
||||
p_b_thread[3],
|
||||
p_b_thread[4],
|
||||
p_b_thread[5],
|
||||
p_b_thread[6],
|
||||
p_b_thread[7]);
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread + ib * ThreadMatrixStrideC);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t K = a_block_mtx.NRow(); // A is transposed
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
// A is transposed, b is not
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
|
||||
is_same<FloatC, float>::value,
|
||||
"Run_asm only deal with float\n");
|
||||
|
||||
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
|
||||
MPerThread == 8 && NPerThread == 8,
|
||||
"Run_asm cannot deal with this GEMM shape yet\n");
|
||||
|
||||
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");
|
||||
|
||||
static_assert(
|
||||
BlockMatrixStrideA == 0 && BatchPerThread == 1,
|
||||
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
|
||||
|
||||
using Float4 = vector_type<float, 4>::MemoryType;
|
||||
|
||||
Float4* reg_a = (Float4*)(p_a_thread);
|
||||
Float4* reg_b = (Float4*)(p_b_thread);
|
||||
Float4* reg_c = (Float4*)(p_c_thread);
|
||||
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
|
||||
reg_b[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) +
|
||||
mMyThreadOffsetB]);
|
||||
reg_a[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(
|
||||
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(
|
||||
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
reg_b[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) +
|
||||
mMyThreadOffsetB]);
|
||||
reg_a[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow(); // A is transposed
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
// A is transposed, B is not
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deals with float\n");
|
||||
|
||||
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
|
||||
MPerThread == 8 && NPerThread == 8,
|
||||
"Run_asm cannot deal with this GEMM shape yet\n");
|
||||
|
||||
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only does float4 reads\n");
|
||||
|
||||
static_assert(
|
||||
BlockMatrixStrideA == 0 && BatchPerThread == 1,
|
||||
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
|
||||
|
||||
using Float4 = vector_type<float, 4>::MemoryType;
|
||||
|
||||
Float4* reg_a = (Float4*)(p_a_thread);
|
||||
Float4* reg_b = (Float4*)(p_b_thread);
|
||||
Float4* reg_c = (Float4*)(p_c_thread);
|
||||
|
||||
void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
|
||||
void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);
|
||||
|
||||
constexpr index_t a_lds_row_stride = sizeof(float) * a_block_mtx.RowStride();
|
||||
constexpr index_t b_lds_row_stride = sizeof(float) * b_block_mtx.RowStride();
|
||||
constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
|
||||
constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;
|
||||
|
||||
ds_read_b128(reg_a[0], a_lds_loc, 0);
|
||||
ds_read_b128(reg_b[0], b_lds_loc, 0);
|
||||
ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
|
||||
ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
|
||||
lgkmcnt(2);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
|
||||
ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
|
||||
lgkmcnt(2);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
lgkmcnt(1);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
}
|
||||
|
||||
lgkmcnt(0);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
|
||||
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
|
||||
FloatC* __restrict__ p_c_block) const
|
||||
{
|
||||
constexpr auto c_block_mtx = BlockMatrixC{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t c_thread_offset =
|
||||
c_thread_mtx_begin.batch * BlockMatrixStrideC +
|
||||
c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
|
||||
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
c_thread_sub_mtx,
|
||||
p_c_thread +
|
||||
c_thread_sub_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster),
|
||||
c_block_mtx,
|
||||
p_c_block +
|
||||
c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster) +
|
||||
c_thread_offset,
|
||||
c_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
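For reference, Run_asm above accumulates an 8x8 per-thread C tile held in sixteen float4 registers: in each K step, two float4 A fragments and two float4 B fragments feed four 4x4 outer products, and Run_asm_v2 only changes how the fragments are fetched (ds_read_b128 plus lgkmcnt waits instead of plain LDS loads). Below is a minimal host-side sketch of that accumulation pattern; it assumes outerProduct4x4(a, b, c0, c1, c2, c3) performs ci[j] += a[i] * b[j], which this file does not spell out.

// --- illustrative host-side sketch (not part of the kernel source) ---
#include <array>
#include <cstdio>

using Float4 = std::array<float, 4>;

// assumed semantics of outerProduct4x4: ci[j] += a[i] * b[j]
void outer_product_4x4(const Float4& a, const Float4& b,
                       Float4& c0, Float4& c1, Float4& c2, Float4& c3)
{
    Float4* c[4] = {&c0, &c1, &c2, &c3};
    for(int i = 0; i < 4; ++i)
        for(int j = 0; j < 4; ++j)
            (*c[i])[j] += a[i] * b[j];
}

int main()
{
    Float4 reg_a[2] = {{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}; // two 4-row A fragments
    Float4 reg_b[2] = {{{1, 1, 1, 1}}, {{2, 2, 2, 2}}}; // two 4-column B fragments
    Float4 reg_c[16] = {};                              // 8x8 accumulator as 16 x float4

    // same register numbering as Run_asm: rows 0-3 use reg_a[0], rows 4-7 use reg_a[1];
    // even reg_c entries hold the first 4 columns, odd entries the last 4 columns
    outer_product_4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
    outer_product_4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
    outer_product_4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
    outer_product_4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);

    printf("c[0][0] = %f\n", reg_c[0][0]); // 1 * 1 = 1 after one K step
    return 0;
}
// --- end sketch ---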
composable_kernel/include/tensor_operation/blockwise_gemm.hpp
@@ -0,0 +1,433 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_HPP
|
||||
#define CK_BLOCKWISE_GEMM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "threadwise_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// if the following numbers are powers of 2, index calculation can be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
|
||||
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t MLevel0Cluster,
|
||||
index_t NLevel0Cluster,
|
||||
index_t MLevel1Cluster,
|
||||
index_t NLevel1Cluster,
|
||||
index_t KPerThreadLoop,
|
||||
index_t DataPerReadA,
|
||||
index_t DataPerReadB>
|
||||
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
{
|
||||
struct MatrixIndex
|
||||
{
|
||||
index_t row;
|
||||
index_t col;
|
||||
};
|
||||
|
||||
index_t mMyThreadOffsetA;
|
||||
index_t mMyThreadOffsetB;
|
||||
|
||||
__device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
|
||||
{
|
||||
constexpr index_t ThreadPerLevel1Cluster =
|
||||
MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;
|
||||
|
||||
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
|
||||
|
||||
static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
|
||||
"wrong! K dimension not consistent\n");
|
||||
|
||||
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
|
||||
constexpr index_t N = BlockMatrixB::NCol();
|
||||
constexpr index_t K = BlockMatrixA::NRow();
|
||||
|
||||
static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong! cannot evenly divide work among threads\n");
|
||||
|
||||
static_assert(is_same_type(ThreadMatrixC::GetLengths(), GetThreadMatrixCLengths()),
|
||||
"wrong! ThreadMatrixC lengths is wrong");
|
||||
|
||||
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
|
||||
mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetThreadMatrixCLengths()
|
||||
{
|
||||
constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
|
||||
constexpr index_t N = BlockMatrixB::NCol();
|
||||
|
||||
constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
|
||||
constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
|
||||
|
||||
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
|
||||
}
|
||||
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
|
||||
{
|
||||
constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
|
||||
|
||||
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
|
||||
index_t level1_m_id = level1_id / NLevel1Cluster;
|
||||
index_t level1_n_id = level1_id % NLevel1Cluster;
|
||||
|
||||
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
|
||||
index_t level0_m_id = level0_id / NLevel0Cluster;
|
||||
index_t level0_n_id = level0_id % NLevel0Cluster;
|
||||
|
||||
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
|
||||
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
|
||||
|
||||
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
|
||||
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
|
||||
}
|
||||
|
||||
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
|
||||
index_t n_in_c)
|
||||
{
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
index_t m_repeat = m_in_c / MPerThreadSubC;
|
||||
index_t n_repeat = n_in_c / NPerThreadSubC;
|
||||
|
||||
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
|
||||
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
|
||||
|
||||
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
|
||||
n_repeat * NPerLevel1Cluster + n_in_sub_c};
|
||||
}
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
// TODO: this is not working correctly
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run_asm(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deals with float\n");
|
||||
|
||||
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
|
||||
MPerThread == 8 && NPerThread == 8,
|
||||
"Run_asm cannot deal with this GEMM shape yet\n");
|
||||
|
||||
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only does float4 reads\n");
|
||||
|
||||
using Float4 = vector_type<float, 4>::MemoryType;
|
||||
|
||||
Float4* reg_a = (Float4*)(p_a_thread);
|
||||
Float4* reg_b = (Float4*)(p_b_thread);
|
||||
Float4* reg_c = (Float4*)(p_c_thread);
|
||||
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
|
||||
reg_b[1] =
|
||||
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
|
||||
reg_a[1] =
|
||||
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
reg_b[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
|
||||
reg_a[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA* const __restrict__ p_a_block,
|
||||
const FloatB* const __restrict__ p_b_block,
|
||||
FloatC* const __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;
|
||||
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
#pragma unroll
|
||||
// copy A-sub to form A
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
p_a_block +
|
||||
a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
// copy B-sub to form B
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
p_b_block +
|
||||
b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
|
||||
mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread);
|
||||
}
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
|
||||
FloatB* const p_b_block,
|
||||
FloatC* p_c_thread) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
// register
|
||||
FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// preload A, B
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_0 + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_0 + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
bool even_loop = true;
|
||||
|
||||
#pragma unroll
|
||||
for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
|
||||
k_begin += KPerThreadLoop, even_loop = !even_loop)
|
||||
{ // loop over k
|
||||
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
|
||||
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
|
||||
|
||||
FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
|
||||
FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;
|
||||
|
||||
// preload next A, B
|
||||
#pragma unroll
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{ // copy A-sub to form A
|
||||
threadwise_matrix_copy(a_block_mtx,
|
||||
p_a_block + mMyThreadOffsetA +
|
||||
(k_begin + 1) * a_block_mtx.RowStride() +
|
||||
m_repeat * MPerLevel1Cluster,
|
||||
a_thread_sub_mtx,
|
||||
p_a_thread_next + m_repeat * MPerThreadSubC,
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{ // copy B-sub to form B
|
||||
threadwise_matrix_copy(b_block_mtx,
|
||||
p_b_block + mMyThreadOffsetB +
|
||||
(k_begin + 1) * b_block_mtx.RowStride() +
|
||||
n_repeat * NPerLevel1Cluster,
|
||||
b_thread_sub_mtx,
|
||||
p_b_thread_next + n_repeat * NPerThreadSubC,
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread_now,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread_now,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread);
|
||||
}
|
||||
|
||||
// last loop
|
||||
{
|
||||
FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
|
||||
FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread_now,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread_now,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
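The thread-to-tile mapping in GetBeginOfThreadMatrixC above can be checked with a few lines of host code: a thread id is first split into a level-1 cluster id and a level-0 id, and each level contributes a row/column offset in units of the per-thread sub-tile. The cluster sizes below are illustrative only, not taken from any configuration in this commit.

// --- illustrative host-side sketch (not part of the kernel source) ---
#include <cstdio>
#include <initializer_list>

struct MatrixIndex { int row; int col; };

constexpr int MPerThreadSubC = 4, NPerThreadSubC = 4; // example values
constexpr int MLevel0Cluster = 4, NLevel0Cluster = 4;
constexpr int MLevel1Cluster = 4, NLevel1Cluster = 4;

MatrixIndex get_begin_of_thread_matrix_c(int thread_id)
{
    constexpr int ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

    int level1_id   = thread_id / ThreadPerLevel0Cluster;
    int level1_m_id = level1_id / NLevel1Cluster;
    int level1_n_id = level1_id % NLevel1Cluster;

    int level0_id   = thread_id % ThreadPerLevel0Cluster;
    int level0_m_id = level0_id / NLevel0Cluster;
    int level0_n_id = level0_id % NLevel0Cluster;

    constexpr int MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
    constexpr int NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

    return {level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
            level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}

int main()
{
    // with these example values the block has 4*4*4*4 = 256 threads
    for(int tid : {0, 1, 16, 255})
    {
        MatrixIndex idx = get_begin_of_thread_matrix_c(tid);
        printf("thread %3d -> C tile origin (row %3d, col %3d)\n", tid, idx.row, idx.col);
    }
    return 0;
}
// --- end sketch ---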
@@ -0,0 +1,401 @@
|
||||
#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
|
||||
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
// memory layout (ordering of dimensions) can be different between src and dst
|
||||
// For now, only support SubLengths[...] == 1 on a merged dimension
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SliceLengths,
|
||||
class SubLengths,
|
||||
class DataClusterLengths,
|
||||
class ThreadClusterArrangeOrder,
|
||||
class SrcAccessOrder,
|
||||
class DstAccessOrder,
|
||||
index_t SrcDataPerRead,
|
||||
index_t DstDataPerWrite>
|
||||
struct BlockwiseGenericTensorSliceCopy_v1
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
static constexpr index_t nOriginalDimSrc =
|
||||
SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
|
||||
static constexpr index_t nOriginalDimDst =
|
||||
DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension();
|
||||
|
||||
// per-thread offset
|
||||
index_t mThreadSrcOffset;
|
||||
index_t mThreadDstOffset;
|
||||
|
||||
// "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets, "mThreadDstOriginalMultiId",
|
||||
// "mThreadDstPartialOffsets" are always calculated inside constructor, and would be
|
||||
// updated if slicing-window is moved. However, they will not be used if you always move
|
||||
// the slicing-window along a non-merged dimension. In that case, compiler should be
|
||||
// able to remove these calculation.
|
||||
// TODO: make sure compiler would actually remove them in that case
|
||||
|
||||
// partial offset in each (merged) dimension
|
||||
Array<index_t, nDim> mThreadSrcPartialOffsets;
|
||||
Array<index_t, nDim> mThreadDstPartialOffsets;
|
||||
|
||||
// multi-id of original tensor
|
||||
Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
|
||||
Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
|
||||
|
||||
__device__
|
||||
BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_multi_id_begin,
|
||||
Array<index_t, nDim> dst_block_data_multi_id_begin)
|
||||
{
|
||||
// check NDim consistency
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SubLengths::GetSize() && nDim == DataClusterLengths::GetSize() &&
|
||||
nDim == ThreadClusterArrangeOrder::GetSize() &&
|
||||
nDim == SrcAccessOrder::GetSize() && nDim == DstAccessOrder::GetSize(),
|
||||
"wrong");
|
||||
|
||||
// check thread arrange order and read/write access order are valid
|
||||
static_assert(is_valid_sequence_map<ThreadClusterArrangeOrder>::value &&
|
||||
is_valid_sequence_map<SrcAccessOrder>::value &&
|
||||
is_valid_sequence_map<DstAccessOrder>::value,
|
||||
"wrong!");
|
||||
|
||||
// thread cluster
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
|
||||
DataClusterLengths{}.ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
|
||||
|
||||
// BlockSize
|
||||
static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize");
|
||||
|
||||
// divide work
|
||||
constexpr auto data_per_cluster_per_dims = SubLengths{} * DataClusterLengths{};
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
static_assert(SliceLengths::Get(IDim) % SubLengths::Get(IDim) == 0,
|
||||
"wrong! cannot evenly divide sliced tensor into sub-tensor");
|
||||
|
||||
static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0,
|
||||
"wrong! cannot evenly divide sliced tensor into cluster");
|
||||
});
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims;
|
||||
|
||||
// for now, only support SubLengths.Get() == 1 on a merged dimension that contains
// multiple original dimensions
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
|
||||
static_assert(SubLengths::Get(IDim) == 1 ||
(!SrcDesc::ContainMultipleOriginalDimensions(IDim) &&
!DstDesc::ContainMultipleOriginalDimensions(IDim)),
"wrong! only support SubLength == 1 on a merged dimension");
|
||||
});
|
||||
|
||||
// calculate mThreadSrcOffset, mThreadDstOffset
|
||||
const auto thread_cluster_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
const auto data_cluster_multi_id =
|
||||
reorder_array_given_old2new(thread_cluster_multi_id, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto thread_data_multi_id_begin = data_cluster_multi_id * SubLengths{};
|
||||
|
||||
// original multi-id
|
||||
mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex(
|
||||
src_block_data_multi_id_begin + thread_data_multi_id_begin);
|
||||
|
||||
mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex(
|
||||
dst_block_data_multi_id_begin + thread_data_multi_id_begin);
|
||||
|
||||
// partial offset on each dimension
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
constexpr index_t idim = IDim.Get();
|
||||
|
||||
constexpr auto src_partial_original_dims =
|
||||
SrcDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto src_partial_original_desc =
|
||||
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
|
||||
|
||||
mThreadSrcPartialOffsets(idim) = src_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims));
|
||||
});
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim_) {
|
||||
constexpr auto IDim = decltype(IDim_){};
|
||||
constexpr index_t idim = IDim.Get();
|
||||
|
||||
constexpr auto dst_partial_original_dims =
|
||||
DstDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto dst_partial_original_desc =
|
||||
DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims);
|
||||
|
||||
mThreadDstPartialOffsets(idim) = dst_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims));
|
||||
});
|
||||
|
||||
// complete offset
|
||||
mThreadSrcOffset = accumulate_on_array(
|
||||
mThreadSrcPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
|
||||
|
||||
mThreadDstOffset = accumulate_on_array(
|
||||
mThreadDstPartialOffsets, math::plus<index_t>{}, static_cast<index_t>(0));
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("id %5u %5u: "
|
||||
"src_block_data_multi_id_begin: %u %u %u %u, "
|
||||
"thread_cluster_multi_id: %u %u %u %u, "
|
||||
"data_cluster_multi_id: %u %u %u %u, "
|
||||
"thread_data_multi_id_begin: %u %u %u %u, "
|
||||
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
src_block_data_multi_id_begin[0],
|
||||
src_block_data_multi_id_begin[1],
|
||||
src_block_data_multi_id_begin[2],
|
||||
src_block_data_multi_id_begin[3],
|
||||
thread_cluster_multi_id[0],
|
||||
thread_cluster_multi_id[1],
|
||||
thread_cluster_multi_id[2],
|
||||
thread_cluster_multi_id[3],
|
||||
data_cluster_multi_id[0],
|
||||
data_cluster_multi_id[1],
|
||||
data_cluster_multi_id[2],
|
||||
data_cluster_multi_id[3],
|
||||
thread_data_multi_id_begin[0],
|
||||
thread_data_multi_id_begin[1],
|
||||
thread_data_multi_id_begin[2],
|
||||
thread_data_multi_id_begin[3],
|
||||
mThreadSrcOffset,
|
||||
mThreadDstOffset);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
{
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
|
||||
|
||||
return thread_tensor_desc.GetElementSpace();
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if 0
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
const auto src_thread_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
#else // HIP compiler performs better with this code
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto src_thread_data_multi_id_begin =
|
||||
repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
constexpr auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr index_t src_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_thread_data_multi_id_begin);
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(SrcDesc{},
|
||||
p_src + src_offset + mThreadSrcOffset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
thread_sub_tensor_lengths,
|
||||
SrcAccessOrder{},
|
||||
Number<SrcDataPerRead>{});
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
constexpr auto data_per_cluster_per_dims = thread_sub_tensor_lengths * DataClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * DataClusterLengths{});
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_sub_tensor_lengths * repeat_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
#if 0
|
||||
constexpr auto repeat_multi_id = sequence2array(decltype(repeat_multi_id_){});
|
||||
|
||||
const auto clipboard_data_multi_id_begin = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
const auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
const index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
|
||||
const index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#else // HIP compiler performs better with this code
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto clipboard_data_multi_id_begin =
|
||||
repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr auto dst_data_multi_id_begin = repeat_multi_id * data_per_cluster_per_dims;
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id_begin);
|
||||
|
||||
constexpr index_t dst_offset =
|
||||
DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id_begin);
|
||||
#endif
|
||||
|
||||
threadwise_generic_tensor_slice_copy_v1(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
DstDesc{},
|
||||
p_dst + dst_offset + mThreadDstOffset,
|
||||
make_zero_array<index_t, nDim>(),
|
||||
thread_sub_tensor_lengths,
|
||||
DstAccessOrder{},
|
||||
Number<DstDataPerWrite>{});
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
Float p_clipboard[GetRegisterClipboardSize()];
|
||||
|
||||
RunLoadRegisterClipboard(p_src, p_clipboard);
|
||||
RunStoreRegisterClipboard(p_clipboard, p_dst);
|
||||
}
|
||||
|
||||
// When moving the slicing window along a merged dimension, if the strides of the
// contained (by the merged dimension) original dimensions are in descending order,
// then there is no guarantee that the new offset will be larger than the old offset
// for movement in the positive direction (and vice versa for movement in the negative direction).
// As a result, there is the possibility that the offset calculation may result in
// unsigned integer underflow (due to the "-" operation). However, this hazard should not
// happen, as long as the users make sure the slicing window would not be moved out of
// the boundary of the tensor being sliced. This function doesn't do a runtime sanity
// check on an out-of-bound slicing window, for performance reasons
|
||||
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
|
||||
__device__ void MoveSlicingWindowOnSourceTensor(
|
||||
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
constexpr index_t idim = IDim.Get();
|
||||
|
||||
static_if<SrcDesc::ContainMultipleOriginalDimensions(IDim)>{}([&](auto fwd) {
|
||||
// logic for a merged dimension; it also works for a non-merged dimension, but may
// be unnecessarily complicated for the compiler to remove calculations that are useless for
// a non-merged dimension
|
||||
|
||||
// extract partial original dimensions
|
||||
constexpr auto src_partial_original_dims =
|
||||
SrcDesc::GetContainedOriginalDimensions(IDim);
|
||||
|
||||
constexpr auto src_partial_original_desc =
|
||||
SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
|
||||
|
||||
// calculate new partial original multi-id
|
||||
auto old_src_partial_original_multi_id =
|
||||
extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
|
||||
|
||||
auto new_src_partial_original_multi_id =
|
||||
src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
|
||||
old_src_partial_original_multi_id, StepSize, direction);
|
||||
|
||||
// update "mThreadSrcOriginalMultiId"
|
||||
static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I_) {
|
||||
constexpr auto I = decltype(I_){};
|
||||
constexpr index_t idim_original = src_partial_original_dims.Get(I);
|
||||
|
||||
mThreadSrcOriginalMultiId(idim_original) =
|
||||
new_src_partial_original_multi_id[I.Get()];
|
||||
});
|
||||
|
||||
// calculate new partial offset on this merged dimension
|
||||
const index_t old_src_partial_offset = mThreadSrcPartialOffsets[idim];
|
||||
|
||||
const index_t new_src_partial_offset =
|
||||
src_partial_original_desc.GetOffsetFromMultiIndex(
|
||||
new_src_partial_original_multi_id);
|
||||
|
||||
// update "mThreadSrcPartialOffsets"
|
||||
mThreadSrcPartialOffsets(idim) = new_src_partial_offset;
|
||||
|
||||
// update "mThreadSrcOffset", do "+" before "-" to avoid underflow
|
||||
mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
|
||||
}).Else([&](auto fwd) {
|
||||
// Logic for a non-merged dimension. If you are never going to move the slicing window on
// a merged dimension, then "mThreadSrcOriginalMultiId" and "mThreadSrcPartialOffsets",
// which are being calculated here, will never be used later. In this case, the compiler
// should be able to remove these calculations.
// TODO: make sure the compiler would actually remove them in this case.

// It is the user's responsibility to make sure the slicing window will not be moved out
// of the boundary of the tensor being sliced. Otherwise, there might be a hazard like
// unsigned integer underflow. There is NO runtime sanity check to prevent the hazard.
|
||||
|
||||
constexpr index_t idim_original = SrcDesc::GetContainedOriginalDimensions(IDim).Front();
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto fwd) {
|
||||
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
|
||||
mThreadSrcOriginalMultiId(idim_original) += StepSize;
|
||||
|
||||
mThreadSrcPartialOffsets(idim) += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
}).Else([&](auto fwd) {
|
||||
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
|
||||
mThreadSrcOriginalMultiId(idim_original) -= StepSize;
|
||||
|
||||
mThreadSrcPartialOffsets(idim) -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
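To make the work division in BlockwiseGenericTensorSliceCopy_v1 above concrete: each thread copies a SubLengths tile per pass, one pass of the whole thread cluster covers SubLengths * DataClusterLengths elements per dimension, the slice is walked repeat_lengths = SliceLengths / (SubLengths * DataClusterLengths) times, and the per-thread register clipboard therefore holds SubLengths * repeat_lengths elements. A small host-side check of GetRegisterClipboardSize with made-up lengths:

// --- illustrative host-side sketch (not part of the kernel source) ---
#include <cstdio>

int main()
{
    constexpr int nDim = 4;
    // all lengths below are made up; a real kernel gets them from template parameters
    int slice_lengths[nDim]        = {2, 8, 16, 16}; // data copied by the whole block
    int sub_lengths[nDim]          = {1, 1, 2, 4};   // per-thread tile per pass
    int data_cluster_lengths[nDim] = {2, 8, 4, 4};   // thread cluster shape (256 threads)

    int clipboard_size = 1;
    for(int i = 0; i < nDim; ++i)
    {
        int per_cluster = sub_lengths[i] * data_cluster_lengths[i];
        int repeat      = slice_lengths[i] / per_cluster; // must divide evenly
        clipboard_size *= sub_lengths[i] * repeat;        // thread tensor length in dim i
    }

    // 1*1 * 1*1 * (2*2) * (4*1) = 16 elements of thread-private clipboard
    printf("register clipboard size per thread = %d\n", clipboard_size);
    return 0;
}
// --- end sketch ---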
@@ -0,0 +1,299 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcLengths,
|
||||
class SrcSubLengths,
|
||||
class SrcClusterLengths,
|
||||
class MapDst2Src,
|
||||
class MapThreadCluster2SrcCluster,
|
||||
index_t SrcDataPerRead,
|
||||
index_t DstDataPerWrite>
|
||||
struct BlockwiseTensorSliceReorderCopy_v3
|
||||
{
|
||||
static constexpr index_t nDim = SrcLengths::GetSize();
|
||||
|
||||
index_t mThreadSrcOffset;
|
||||
index_t mThreadDstOffset;
|
||||
|
||||
__device__
|
||||
BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
|
||||
Array<index_t, nDim> dst_block_data_multi_id_begin)
|
||||
{
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto src_lengths = SrcLengths{};
|
||||
|
||||
constexpr auto map_dst2src = MapDst2Src{};
|
||||
|
||||
constexpr auto src_sub_lengths = SrcSubLengths{};
|
||||
constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);
|
||||
|
||||
constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};
|
||||
|
||||
constexpr auto src_cluster_lengths = SrcClusterLengths{};
|
||||
constexpr auto thread_cluster_lengths =
|
||||
src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);
|
||||
|
||||
constexpr auto thread_cluster_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_cluster_lengths);
|
||||
|
||||
// sanity check: data type
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");
|
||||
|
||||
// sanity check: nDim
|
||||
static_assert(SrcDesc::GetNumOfDimension() == nDim &&
|
||||
DstDesc::GetNumOfDimension() == nDim && SrcLengths::GetSize() == nDim &&
|
||||
SrcSubLengths::GetSize() == nDim &&
|
||||
SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
|
||||
MapThreadCluster2SrcCluster::GetSize() == nDim,
|
||||
"wrong! nDim is not consistent\n");
|
||||
|
||||
// sanity check: BlockSize
|
||||
constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();
|
||||
|
||||
static_assert(BlockSize >= num_active_thread,
|
||||
"wrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
// sanity check: work division
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto I = decltype(IDim){};
|
||||
constexpr index_t src_len = src_lengths.Get(I);
|
||||
constexpr index_t src_sub_len = src_sub_lengths.Get(I);
|
||||
constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
|
||||
static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
|
||||
"wrong! cannot evenly divide Src tensor lengths");
|
||||
});
|
||||
|
||||
// sanity check: src read
|
||||
static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
|
||||
"wrong! only support SrcDataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");
|
||||
|
||||
static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
|
||||
"wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");
|
||||
|
||||
static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
|
||||
"wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
|
||||
"keep alignment");
|
||||
|
||||
// sanity check: dst write
|
||||
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
|
||||
"wrong! only support DstDataPerWrite == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");
|
||||
|
||||
static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
|
||||
"wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");
|
||||
|
||||
static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
|
||||
"wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
|
||||
"keep alignment");
|
||||
|
||||
// start dividing work
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const auto thread_multi_id =
|
||||
thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id());
|
||||
|
||||
// compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
// registers, or only one copy???
|
||||
auto src_data_multi_id =
|
||||
reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto I = decltype(IDim){};
|
||||
constexpr index_t i = I.Get();
|
||||
// compiler: will it really compute index here, or be merged with
|
||||
// GetOffsetFromMultiIndex and
|
||||
// optimized away???
|
||||
src_data_multi_id(i) *= src_sub_lengths.Get(I);
|
||||
});
|
||||
|
||||
// compiler: will it really compute index here, or be merged with GetOffsetFromMultiIndex
|
||||
// and
|
||||
// optimized away???
|
||||
const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);
|
||||
|
||||
mThreadSrcOffset =
|
||||
src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);
|
||||
|
||||
mThreadDstOffset =
|
||||
dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
|
||||
}
|
||||
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("id %5u %5u: "
|
||||
"thread_multi_id: %u %u, "
|
||||
"src_block_data_multi_id_begin: %u %u, "
|
||||
"src_data_multi_id: %u %u, "
|
||||
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
src_block_data_multi_id_begin[0],
|
||||
src_block_data_multi_id_begin[1],
|
||||
src_data_multi_id[0],
|
||||
src_data_multi_id[1],
|
||||
mThreadSrcOffset,
|
||||
mThreadDstOffset);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * SrcClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = transform_sequences(
|
||||
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
|
||||
|
||||
return thread_tensor_desc.GetElementSpace();
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * SrcClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = transform_sequences(
|
||||
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
|
||||
|
||||
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr index_t src_offset = SrcDesc{}.GetOffsetFromMultiIndex(src_data_multi_id);
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
|
||||
|
||||
threadwise_tensor_slice_copy(SrcDesc{},
|
||||
p_src + src_offset + mThreadSrcOffset,
|
||||
thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
thread_sub_tensor_lengths,
|
||||
Number<SrcDataPerRead>{});
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims =
|
||||
thread_sub_tensor_lengths * SrcClusterLengths{};
|
||||
|
||||
constexpr auto repeat_lengths = transform_sequences(
|
||||
math::integer_divide_ceiler<index_t>{}, SrcLengths{}, src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = thread_sub_tensor_lengths * repeat_lengths;
|
||||
|
||||
constexpr auto thread_tensor_desc =
|
||||
make_ConstantTensorDescriptor_packed(thread_tensor_lengths);
|
||||
|
||||
static_ford<decltype(repeat_lengths)>{}([&](auto repeat_multi_id_) {
|
||||
constexpr auto repeat_multi_id = decltype(repeat_multi_id_){};
|
||||
|
||||
constexpr auto clipboard_data_multi_id = repeat_multi_id * thread_sub_tensor_lengths;
|
||||
|
||||
constexpr auto src_data_multi_id = repeat_multi_id * src_data_per_cluster_per_dims;
|
||||
|
||||
// reorder src_data_multi_id to get dst_data_multi_id
|
||||
constexpr auto dst_data_multi_id = src_data_multi_id.ReorderGivenNew2Old(MapDst2Src{});
|
||||
|
||||
constexpr index_t clipboard_offset =
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
|
||||
|
||||
constexpr index_t dst_offset = DstDesc{}.GetOffsetFromMultiIndex(dst_data_multi_id);
|
||||
|
||||
// write in the order of dst
|
||||
#if 1
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
DstDesc{},
|
||||
p_dst + dst_offset +
|
||||
mThreadDstOffset,
|
||||
thread_sub_tensor_lengths,
|
||||
MapDst2Src{});
|
||||
#else
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
DstDesc{},
|
||||
p_dst + dst_offset +
|
||||
mThreadDstOffset,
|
||||
thread_sub_tensor_lengths,
|
||||
MapDst2Src{},
|
||||
Number<DstDataPerWrite>{});
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
Float p_clipboard[GetRegisterClipboardSize()];
|
||||
|
||||
RunLoadRegisterClipboard(p_src, p_clipboard);
|
||||
RunStoreRegisterClipboard(p_clipboard, p_dst);
|
||||
}
|
||||
|
||||
// this function doesn't do a sanity check on whether the slicing window is out of the boundary
// of the tensor being sliced
|
||||
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
|
||||
__device__ void MoveSlicingWindowOnSourceTensor(
|
||||
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto fwd) {
|
||||
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
}).Else([&](auto fwd) { mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim); });
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
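The reorder in BlockwiseTensorSliceReorderCopy_v3 above is driven entirely by the MapDst2Src sequence: source multi-indices are computed in source order, then permuted into destination order before the destination offset is taken. A minimal host-side sketch of that convention, assuming (as the names suggest) a new-to-old reorder produces new[i] = old[map[i]]; the 4-d permutation below is purely hypothetical.

// --- illustrative host-side sketch (not part of the kernel source) ---
#include <array>
#include <cstdio>

// assumed convention: the i-th new entry comes from old entry map_new2old[i]
template <std::size_t N>
std::array<int, N> reorder_given_new2old(const std::array<int, N>& old_array,
                                         const std::array<int, N>& map_new2old)
{
    std::array<int, N> new_array{};
    for(std::size_t i = 0; i < N; ++i)
        new_array[i] = old_array[map_new2old[i]];
    return new_array;
}

int main()
{
    std::array<int, 4> src_multi_id = {1, 2, 3, 4}; // e.g. (n, c, h, w)
    std::array<int, 4> map_dst2src  = {1, 2, 3, 0}; // hypothetical: dst order (c, h, w, n)

    auto dst_multi_id = reorder_given_new2old(src_multi_id, map_dst2src);

    printf("dst multi-index = %d %d %d %d\n",
           dst_multi_id[0], dst_multi_id[1], dst_multi_id[2], dst_multi_id[3]); // 2 3 4 1
    return 0;
}
// --- end sketch ---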
@@ -0,0 +1,60 @@
|
||||
#ifndef CK_THREADWISE_4D_TENSOR_OP_HPP
|
||||
#define CK_THREADWISE_4D_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Float, class Desc, class IDim, class NShift>
|
||||
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto desc = Desc{};
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr index_t nshift = NShift::mValue;
|
||||
|
||||
constexpr index_t did0_end =
|
||||
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
|
||||
|
||||
constexpr index_t did1_end =
|
||||
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
|
||||
|
||||
constexpr index_t did2_end =
|
||||
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
|
||||
|
||||
constexpr index_t did3_end =
|
||||
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
|
||||
|
||||
for(index_t did0 = 0; did0 < did0_end; ++did0)
|
||||
{
|
||||
for(index_t did1 = 0; did1 < did1_end; ++did1)
|
||||
{
|
||||
for(index_t did2 = 0; did2 < did2_end; ++did2)
|
||||
{
|
||||
for(index_t did3 = 0; did3 < did3_end; ++did3)
|
||||
{
|
||||
const index_t dindex = desc.GetOffsetFromMultiIndex(did0, did1, did2, did3);
|
||||
|
||||
const index_t sindex = dindex + nshift * desc.GetStride(IDim{});
|
||||
|
||||
p[dindex] = p[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
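// A minimal sketch of what the shift does, assuming a packed 1x1x1x4 tensor (the descriptor,
// buffer and values below are only for illustration):
#if 0
__device__ void threadwise_4d_tensor_shift_down_example()
{
    constexpr auto desc = make_ConstantTensorDescriptor_packed(Sequence<1, 1, 1, 4>{});

    float p[4] = {0, 1, 2, 3};

    // shift by 1 along dimension 3: p[i] receives p[i + 1], the tail element is left as-is
    threadwise_4d_tensor_shift_down(desc, p, Number<3>{}, Number<1>{});

    // p is now {1, 2, 3, 3}
}
#endif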
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,228 @@
|
||||
#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP
|
||||
#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// optimized for the scenario where p_in, p_wei and p_out are all in registers
|
||||
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_1(InDesc,
|
||||
TInWei* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
TInWei* const __restrict__ p_wei,
|
||||
OutDesc,
|
||||
TOut* __restrict__ p_out)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
#if 0
|
||||
if(blockIdx.x == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: ");
|
||||
print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: ");
|
||||
print_ConstantTensorDescriptor(out_desc, "threadwise_direct_convolution: out_desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
|
||||
{
|
||||
for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
|
||||
{
|
||||
for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
|
||||
{
|
||||
for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
|
||||
{
|
||||
for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
|
||||
{
|
||||
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
const index_t hi = ho + y;
|
||||
const index_t wi = wo + x;
|
||||
|
||||
const index_t in_index =
|
||||
in_desc.GetOffsetFromMultiIndex(n, c, hi, wi);
|
||||
|
||||
const index_t wei_index =
|
||||
wei_desc.GetOffsetFromMultiIndex(k, c, y, x);
|
||||
|
||||
const index_t out_index =
|
||||
out_desc.GetOffsetFromMultiIndex(n, k, ho, wo);
|
||||
|
||||
fused_multiply_accumulate(
|
||||
p_out[out_index], p_wei[wei_index], p_in[in_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
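// In index form, the loop nest above computes (unit stride, no padding):
//   out[n][k][ho][wo] += sum over (c, y, x) of wei[k][c][y][x] * in[n][c][ho + y][wo + x]
// A minimal numeric sketch, with hypothetical packed descriptors and values:
#if 0
__device__ void threadwise_direct_convolution_1_example()
{
    constexpr auto in_desc  = make_ConstantTensorDescriptor_packed(Sequence<1, 1, 2, 2>{});
    constexpr auto wei_desc = make_ConstantTensorDescriptor_packed(Sequence<1, 1, 2, 2>{});
    constexpr auto out_desc = make_ConstantTensorDescriptor_packed(Sequence<1, 1, 1, 1>{});

    float p_in[4]  = {1, 2, 3, 4};
    float p_wei[4] = {10, 20, 30, 40};
    float p_out[1] = {0};

    threadwise_direct_convolution_1(in_desc, p_in, wei_desc, p_wei, out_desc, p_out);

    // p_out[0] == 1*10 + 2*20 + 3*30 + 4*40 == 300
}
#endif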
|
||||
|
||||
// optimized for the scenario where p_in and p_wei are in LDS and p_out is in registers
// copy in and wei into registers before doing the convolution
|
||||
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_2(InDesc,
|
||||
TInWei* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
TInWei* const __restrict__ p_wei,
|
||||
OutDesc,
|
||||
TOut* __restrict__ p_out)
|
||||
{
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
constexpr auto in_reg_desc = make_ConstantTensorDescriptor_packed(in_desc.GetLengths());
|
||||
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor_packed(wei_desc.GetLengths());
|
||||
|
||||
// register
|
||||
TInWei p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_tensor_slice_copy(
|
||||
in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{});
|
||||
|
||||
// copy weight tensor into register
|
||||
threadwise_tensor_slice_copy(
|
||||
wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{});
|
||||
|
||||
// do convolution
|
||||
threadwise_direct_convolution_1(
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
}
|
||||
|
||||
// optimized for the scenario where p_in and p_wei are in LDS and p_out is in registers
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
// load the 1x1 weight into registers, and do the 1x1 convolution in registers.
|
||||
template <class Data, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_3(InDesc,
|
||||
Data* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
Data* const __restrict__ p_wei,
|
||||
OutDesc,
|
||||
Data* __restrict__ p_out)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
constexpr auto in_reg_desc = make_ConstantTensorDescriptor(Sequence<in_desc.GetLength(I0),
|
||||
in_desc.GetLength(I1),
|
||||
out_desc.GetLength(I2),
|
||||
out_desc.GetLength(I3)>{});
|
||||
|
||||
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});
|
||||
|
||||
Data p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
Data p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
constexpr index_t in_w_new_read = 1;
|
||||
|
||||
constexpr auto in_desc_reg_new_read =
|
||||
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
|
||||
in_reg_desc.GetLength(I1),
|
||||
in_reg_desc.GetLength(I2),
|
||||
in_w_new_read>{});
|
||||
|
||||
#if 0
|
||||
// this version reuses old input data already in register, and reads only the new data from LDS
|
||||
// loop over vertical direction
|
||||
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
// read first input
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
// read first 1x1 weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, 0),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// do first 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
|
||||
// loop over horizontal direction
|
||||
for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// shift old input to the left
|
||||
threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number<in_w_new_read>{});
|
||||
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(
|
||||
in_desc,
|
||||
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
|
||||
in_reg_desc,
|
||||
p_in_reg +
|
||||
in_reg_desc.GetOffsetFromMultiIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
|
||||
in_desc_reg_new_read.GetLengths());
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
// this version reads all input from LDS each time the filter moves
|
||||
// loop over vertical direction
|
||||
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
// loop over horizontal direction
|
||||
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
// do 1x1 conv
|
||||
threadwise_direct_convolution_1(
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/tensor_operation/threadwise_gemm.hpp
@@ -0,0 +1,123 @@
|
||||
#ifndef CK_THREADWISE_GEMM_HPP
|
||||
#define CK_THREADWISE_GEMM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Float, class Matrix>
|
||||
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
|
||||
{
|
||||
for(index_t i = 0; i < Matrix::NRow(); ++i)
|
||||
{
|
||||
for(index_t j = 0; j < Matrix::NCol(); ++j)
|
||||
{
|
||||
const index_t id = Matrix::GetOffsetFromMultiIndex(i, j);
|
||||
p_thread[id] = Float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Float,
|
||||
class SrcMatrix,
|
||||
class DstMatrix,
|
||||
index_t NRow,
|
||||
index_t NCol,
|
||||
index_t DataPerRead>
|
||||
__device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
const Float* __restrict__ p_src,
|
||||
DstMatrix,
|
||||
Float* __restrict__ p_dst,
|
||||
Sequence<NRow, NCol>,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % DataPerRead == 0");
|
||||
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr auto src_mtx = SrcMatrix{};
|
||||
constexpr auto dst_mtx = DstMatrix{};
|
||||
|
||||
for(index_t i = 0; i < NRow; ++i)
|
||||
{
|
||||
for(index_t j = 0; j < NCol; j += DataPerRead)
|
||||
{
|
||||
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
|
||||
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
|
||||
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class MatrixA,
|
||||
class MatrixB,
|
||||
class MatrixC,
|
||||
bool TransA,
|
||||
bool TransB,
|
||||
bool TransC,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void threadwise_gemm(MatrixA,
|
||||
integral_constant<bool, TransA>,
|
||||
const FloatA* __restrict__ p_a_thread,
|
||||
MatrixB,
|
||||
integral_constant<bool, TransB>,
|
||||
const FloatB* __restrict__ p_b_thread,
|
||||
MatrixC,
|
||||
integral_constant<bool, TransC>,
|
||||
FloatC* __restrict__ p_c_thread)
|
||||
{
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("p_a_thread: %f %f %f %f\n",
|
||||
p_a_thread[0],
|
||||
p_a_thread[1],
|
||||
p_a_thread[2],
|
||||
p_a_thread[3]);
|
||||
printf("p_b_thread: %f %f %f %f\n",
|
||||
p_b_thread[0],
|
||||
p_b_thread[1],
|
||||
p_b_thread[2],
|
||||
p_b_thread[3]);
|
||||
}
|
||||
#endif
|
||||
|
||||
if(TransA && (!TransB) && (!TransC))
|
||||
{
|
||||
constexpr auto a_mtx = MatrixA{};
|
||||
constexpr auto b_mtx = MatrixB{};
|
||||
constexpr auto c_mtx = MatrixC{};
|
||||
|
||||
constexpr index_t M = c_mtx.NRow();
|
||||
constexpr index_t N = c_mtx.NCol();
|
||||
constexpr index_t K = a_mtx.NRow(); // A is transposed
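// the loop nest below therefore computes C(i, j) += sum over k of A(k, i) * B(k, j),
// i.e. C += A^T * B with A stored K x M, B stored K x N and C stored M x N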
|
||||
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t i = 0; i < M; ++i)
|
||||
{
|
||||
for(index_t j = 0; j < N; ++j)
|
||||
{
|
||||
const index_t aindex = a_mtx.GetOffsetFromMultiIndex(k, i); // A is transposed
|
||||
const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
|
||||
const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);
|
||||
|
||||
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// not implemented
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,20 @@
|
||||
#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP
|
||||
#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
template <class Float, class TDesc>
|
||||
__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
|
||||
{
|
||||
static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
|
||||
constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
|
||||
|
||||
p[offset] = static_cast<Float>(0);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,107 @@
|
||||
#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
|
||||
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SliceLengths,
|
||||
class DimAccessOrder,
|
||||
index_t DataPerAccess>
|
||||
__device__ void threadwise_generic_tensor_slice_copy_v1(
|
||||
SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
Array<index_t, SrcDesc::GetNumOfDimension()> src_multi_id_begin,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
Array<index_t, DstDesc::GetNumOfDimension()> dst_multi_id_begin,
|
||||
SliceLengths,
|
||||
DimAccessOrder,
|
||||
Number<DataPerAccess>)
|
||||
{
|
||||
constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
|
||||
nDim == SliceLengths::GetSize() && nDim == DimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");
|
||||
|
||||
#if 0
|
||||
// doesn't compile, because merged-tensor reordering is not implemented
|
||||
// TODO: implement tensor desc ops for merged-tensor
|
||||
constexpr auto src_strides_in_access_order =
|
||||
SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
|
||||
|
||||
constexpr auto dst_strides_in_access_order =
|
||||
SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
|
||||
|
||||
// check src/dst stride on the lowest access dimension
|
||||
static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
|
||||
(DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
|
||||
"wrong! src/dst stride on the lowest access dimension needs to be 1 for "
|
||||
"vectorized read/write");
|
||||
#endif
|
||||
|
||||
constexpr auto slice_lengths_in_access_order =
|
||||
SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});
|
||||
|
||||
// check slice length on the lowest access dimension
|
||||
static_assert(slice_lengths_in_access_order.Back() % DataPerAccess == 0,
|
||||
"wrong! slice length on the lowest access dimension should be evenly divided by "
|
||||
"DataPerAccess");
|
||||
|
||||
constexpr index_t num_access_on_lowest_access_dimension =
|
||||
slice_lengths_in_access_order.Back() / DataPerAccess;
|
||||
|
||||
constexpr auto access_lengths = slice_lengths_in_access_order.Modify(
|
||||
Number<nDim - 1>{}, Number<num_access_on_lowest_access_dimension>{});
|
||||
|
||||
using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;
|
||||
|
||||
#if 1
|
||||
ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
|
||||
auto data_multi_id_in_access_order = access_multi_id;
|
||||
data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;
|
||||
|
||||
const auto data_multi_id =
|
||||
reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
|
||||
|
||||
const index_t src_index =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
|
||||
|
||||
const index_t dst_index =
|
||||
DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
|
||||
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
|
||||
});
|
||||
#else
|
||||
static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
|
||||
constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
|
||||
|
||||
constexpr auto data_multi_id_in_access_order =
|
||||
access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
|
||||
|
||||
constexpr auto data_multi_id = reorder_array_given_old2new(
|
||||
sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
|
||||
|
||||
const index_t src_index =
|
||||
SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
|
||||
|
||||
const index_t dst_index =
|
||||
DstDesc::GetOffsetFromMultiIndex(dst_multi_id_begin + data_multi_id);
|
||||
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
|
||||
});
|
||||
#endif
|
||||
}
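// Access-pattern example (illustrative): for SliceLengths = Sequence<2, 1, 4, 8>,
// DimAccessOrder = Sequence<0, 2, 1, 3> and DataPerAccess = 4, the dimensions are visited
// in the order 0 -> 2 -> 1 -> 3 and the innermost dimension advances 4 elements per access,
// so the copy issues 2 * 4 * 1 * (8 / 4) = 16 vector load/store pairs in total.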
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,202 @@
|
||||
#ifndef CK_THREADWISE_TENSOR_SLICE_COPY_HPP
|
||||
#define CK_THREADWISE_TENSOR_SLICE_COPY_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// need to assume src and dst are aligned
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
|
||||
__device__ void threadwise_tensor_slice_copy(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr index_t nDim = SrcOpLengths::GetSize();
|
||||
|
||||
static_assert(SrcDesc{}.GetNumOfDimension() == nDim && DstDesc{}.GetNumOfDimension() == nDim,
|
||||
"wrong! dimension not consistent");
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor_packed(SrcOpLengths{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "src_desc");
|
||||
print_ConstantTensorDescriptor(dst_desc, "dst_desc");
|
||||
print_ConstantTensorDescriptor(ref_desc, "ref_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number<nDim - 1>{}) == 1 &&
|
||||
DstDesc{}.GetStride(Number<nDim - 1>{}) == 1),
|
||||
"wrong! only support stride[nDim-1] == 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0,
|
||||
"wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L_Back = SrcOpLengths{}.Back();
|
||||
|
||||
static_assert(L_Back % DataPerRead == 0,
|
||||
"wrong! lengths[nDim-1] should be evenly divided by DataPerRead");
|
||||
|
||||
constexpr index_t nRead = L_Back / DataPerRead;
|
||||
|
||||
static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
|
||||
static_for<0, nRead, 1>{}([&](auto IRead) {
|
||||
constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(multi_id);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[src_index]));
|
||||
});
|
||||
});
|
||||
}
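// A minimal usage sketch (hypothetical descriptors and data; DataPerRead = 2, so the last
// dimension is copied as float2 vectors):
#if 0
__device__ void threadwise_tensor_slice_copy_example()
{
    constexpr auto src_desc = make_ConstantTensorDescriptor_packed(Sequence<2, 4>{});
    constexpr auto dst_desc = make_ConstantTensorDescriptor_packed(Sequence<2, 4>{});

    float p_src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float p_dst[8] = {0};

    threadwise_tensor_slice_copy(src_desc, p_src, dst_desc, p_dst, Sequence<2, 4>{}, Number<2>{});

    // p_dst is now an element-wise copy of p_src, done with 4 float2 reads/writes
    // (2 rows x 2 vectors per row)
}
#endif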
|
||||
|
||||
// access in order of src
|
||||
template <class SrcData,
|
||||
class DstData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v1(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
ford<SrcOpLengths>{}([&](auto src_multi_id) {
|
||||
const auto dst_multi_id = reorder_array_given_new2old(src_multi_id, MapDst2Src{});
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
|
||||
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
});
|
||||
}
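// Reorder-map example (illustrative): with SrcOpLengths = Sequence<2, 3> and
// MapDst2Src = Sequence<1, 0>, dst dimension i corresponds to src dimension MapDst2Src[i],
// so the copy writes a 3 x 2 transposed view of the 2 x 3 source slice.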
|
||||
|
||||
// access in order of dst
|
||||
template <class SrcData,
|
||||
class DstData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v2(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});
|
||||
|
||||
ford<decltype(dst_op_lengths)>{}([&](auto dst_multi_id) {
|
||||
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
|
||||
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
});
|
||||
}
|
||||
|
||||
// access in order of dst
|
||||
// manually pack data into vector before write
|
||||
template <class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src,
|
||||
index_t DstDataPerWrite>
|
||||
__device__ void
|
||||
threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src,
|
||||
Number<DstDataPerWrite>)
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DstDataPerWrite>::MemoryType;
|
||||
|
||||
constexpr index_t nDim = SrcOpLengths::GetSize();
|
||||
|
||||
static_assert(DstDataPerWrite == 1 || DstDesc{}.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support dst.stride[nDim-1] == 1, if DstDataPerWrite != 1");
|
||||
|
||||
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
|
||||
"wrong! only support DstDataPerWrite == 1, 2 or 4");
|
||||
|
||||
static_assert(
|
||||
DstDesc{}.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
|
||||
"wrong! dst.stride[nDim-2] should be multiple of DstDataPerWrite to keep alignment");
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto dst_op_lengths = SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{});
|
||||
|
||||
constexpr index_t L_Dst_Back = dst_op_lengths.Back();
|
||||
|
||||
static_assert(L_Dst_Back % DstDataPerWrite == 0,
|
||||
"wrong! dst.lengths[nDim-1] should be evenly divided by DstDataPerWrite");
|
||||
|
||||
constexpr index_t nWrite = L_Dst_Back / DstDataPerWrite;
|
||||
|
||||
ford<decltype(dst_op_lengths.PopBack())>{}([&](auto ids) {
|
||||
static_for<0, nWrite, 1>{}([&](auto IWrite) {
|
||||
vector_t dst_vec_data;
|
||||
|
||||
// pack data
|
||||
static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
|
||||
const auto dst_multi_id =
|
||||
ids.PushBack(IWrite.Get() * DstDataPerWrite + IDstData.Get());
|
||||
|
||||
const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
|
||||
|
||||
const index_t src_index = src_desc.GetOffsetFromMultiIndex(src_multi_id);
|
||||
|
||||
vector_type<Float, DstDataPerWrite>::SetScalar(
|
||||
dst_vec_data, p_src[src_index], IDstData);
|
||||
});
|
||||
|
||||
// write data
|
||||
const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite);
|
||||
|
||||
const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) = dst_vec_data;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/Array.hpp
@@ -0,0 +1,377 @@
|
||||
#ifndef CK_ARRAY_HPP
|
||||
#define CK_ARRAY_HPP
|
||||
|
||||
#include "Sequence.hpp"
|
||||
#include "functional2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
struct Array
|
||||
{
|
||||
using Type = Array<TData, NSize>;
|
||||
|
||||
static constexpr index_t nSize = NSize;
|
||||
|
||||
TData mData[nSize];
|
||||
|
||||
template <class... Xs>
|
||||
__host__ __device__ constexpr Array(Xs... xs) : mData{static_cast<TData>(xs)...}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr index_t GetSize() const { return NSize; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr TData operator[](Number<I>) const
|
||||
{
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr TData operator[](index_t i) const { return mData[i]; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ TData& operator()(Number<I>)
|
||||
{
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
__host__ __device__ TData& operator()(index_t i) { return mData[i]; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void Set(Number<I>, TData x)
|
||||
{
|
||||
static_assert(I < NSize, "wrong!");
|
||||
|
||||
mData[I] = x;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr void Set(index_t I, TData x) { mData[I] = x; }
|
||||
|
||||
struct lambda_PushBack // emulate constexpr lambda
|
||||
{
|
||||
const Array<TData, NSize>& old_array;
|
||||
Array<TData, NSize + 1>& new_array;
|
||||
|
||||
__host__ __device__ constexpr lambda_PushBack(const Array<TData, NSize>& old_array_,
|
||||
Array<TData, NSize + 1>& new_array_)
|
||||
: old_array(old_array_), new_array(new_array_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr void operator()(Number<I>) const
|
||||
{
|
||||
new_array.Set(Number<I>{}, old_array[I]);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr auto PushBack(TData x) const
|
||||
{
|
||||
Array<TData, NSize + 1> new_array;
|
||||
|
||||
static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array));
|
||||
|
||||
new_array.Set(Number<NSize>{}, x);
|
||||
|
||||
return new_array;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence2array(Sequence<Is...>)
|
||||
{
|
||||
return Array<index_t, sizeof...(Is)>{Is...};
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto make_zero_array()
|
||||
{
|
||||
constexpr auto zero_sequence = typename uniform_sequence_gen<NSize, 0>::SeqType{};
|
||||
constexpr auto zero_array = sequence2array(zero_sequence);
|
||||
return zero_array;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return Array<TData, NSize>{old_array[IRs]...};
|
||||
}
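// Example (illustrative): reorder_array_given_new2old(Array<index_t, 3>{10, 20, 30}, Sequence<2, 0, 1>{})
// returns {30, 10, 20}; entry i of the result is taken from position new2old[i] of the input.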
|
||||
|
||||
template <class TData, index_t NSize, class MapOld2New>
|
||||
struct lambda_reorder_array_given_old2new
|
||||
{
|
||||
const Array<TData, NSize>& old_array;
|
||||
Array<TData, NSize>& new_array;
|
||||
|
||||
__host__ __device__ constexpr lambda_reorder_array_given_old2new(
|
||||
const Array<TData, NSize>& old_array_, Array<TData, NSize>& new_array_)
|
||||
: old_array(old_array_), new_array(new_array_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IOldDim>
|
||||
__host__ __device__ constexpr void operator()(Number<IOldDim>) const
|
||||
{
|
||||
TData old_data = old_array[IOldDim];
|
||||
|
||||
constexpr index_t INewDim = MapOld2New::Get(Number<IOldDim>{});
|
||||
|
||||
new_array.Set(Number<INewDim>{}, old_data);
|
||||
}
|
||||
};
|
||||
|
||||
template <class TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
|
||||
Sequence<IRs...> /*old2new*/)
|
||||
{
|
||||
Array<TData, NSize> new_array;
|
||||
|
||||
static_assert(NSize == sizeof...(IRs), "NSize not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_reorder_array_given_old2new<TData, NSize, Sequence<IRs...>>(old_array, new_array));
|
||||
|
||||
return new_array;
|
||||
}
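// Example (illustrative): reorder_array_given_old2new(Array<index_t, 3>{10, 20, 30}, Sequence<2, 0, 1>{})
// scatters entry i of the input to position old2new[i] of the result, giving {20, 30, 10};
// for the same map it is the inverse of reorder_array_given_new2old.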
|
||||
|
||||
template <class TData, index_t NSize, class ExtractSeq>
|
||||
__host__ __device__ constexpr auto extract_array(const Array<TData, NSize>& old_array, ExtractSeq)
|
||||
{
|
||||
Array<TData, ExtractSeq::GetSize()> new_array;
|
||||
|
||||
constexpr index_t new_size = ExtractSeq::GetSize();
|
||||
|
||||
static_assert(new_size <= NSize, "wrong! too many extract");
|
||||
|
||||
static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::Get(I)]; });
|
||||
|
||||
return new_array;
|
||||
}
|
||||
|
||||
template <class F, class X, class Y, class Z> // emulate constepxr lambda for array math
|
||||
struct lambda_array_math
|
||||
{
|
||||
const F& f;
|
||||
const X& x;
|
||||
const Y& y;
|
||||
Z& z;
|
||||
|
||||
__host__ __device__ constexpr lambda_array_math(const F& f_, const X& x_, const Y& y_, Z& z_)
|
||||
: f(f_), x(x_), y(y_), z(z_)
|
||||
{
|
||||
}
|
||||
|
||||
template <index_t IDim_>
|
||||
__host__ __device__ constexpr void operator()(Number<IDim_>) const
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
z.Set(IDim, f(x[IDim], y[IDim]));
|
||||
}
|
||||
};
|
||||
|
||||
// Array = Array + Array
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
|
||||
auto f = math::plus<index_t>{};
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array - Array
|
||||
template <class TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Array<TData, NSize> b)
|
||||
{
|
||||
Array<TData, NSize> result;
|
||||
|
||||
auto f = math::minus<index_t>{};
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array + Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator+(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
auto f = math::plus<index_t>{};
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array - Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator-(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
auto f = math::minus<index_t>{};
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Array * Sequence
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator*(Array<TData, NSize> a, Sequence<Is...> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
auto f = math::multiplies<index_t>{};
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Array = Sequence - Array
|
||||
template <class TData, index_t NSize, index_t... Is>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Is...> a, Array<TData, NSize> b)
|
||||
{
|
||||
static_assert(sizeof...(Is) == NSize, "wrong! size not the same");
|
||||
|
||||
Array<TData, NSize> result;
|
||||
|
||||
auto f = math::minus<index_t>{};
|
||||
|
||||
static_for<0, NSize, 1>{}(
|
||||
lambda_array_math<decltype(f), decltype(a), decltype(b), decltype(result)>(
|
||||
f, a, b, result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class TData, index_t NSize, class Reduce>
|
||||
__host__ __device__ constexpr TData
|
||||
accumulate_on_array(const Array<TData, NSize>& a, Reduce f, TData init)
|
||||
{
|
||||
TData result = init;
|
||||
|
||||
static_assert(NSize > 0, "wrong");
|
||||
|
||||
static_for<0, NSize, 1>{}([&](auto I) { result = f(result, a[I]); });
|
||||
|
||||
return result;
|
||||
}
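// Example (illustrative): accumulate_on_array(Array<index_t, 3>{2, 3, 4}, math::multiplies<index_t>{}, index_t{1})
// folds the entries left to right and returns 1 * 2 * 3 * 4 = 24, e.g. to turn an array of
// lengths into an element count.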
|
||||
|
||||
template <class T, index_t NSize>
|
||||
__host__ __device__ void print_Array(const char* s, Array<T, NSize> a)
|
||||
{
|
||||
constexpr index_t nsize = a.GetSize();
|
||||
|
||||
static_assert(nsize > 0 && nsize <= 10, "wrong!");
|
||||
|
||||
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); });
|
||||
|
||||
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); });
|
||||
|
||||
static_if<nsize == 3>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); });
|
||||
|
||||
static_if<nsize == 4>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); });
|
||||
|
||||
static_if<nsize == 5>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]);
|
||||
});
|
||||
|
||||
static_if<nsize == 6>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]);
|
||||
});
|
||||
|
||||
static_if<nsize == 7>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6]);
|
||||
});
|
||||
|
||||
static_if<nsize == 8>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7]);
|
||||
});
|
||||
|
||||
static_if<nsize == 9>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8]);
|
||||
});
|
||||
|
||||
static_if<nsize == 10>{}([&](auto) {
|
||||
printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nsize,
|
||||
a[0],
|
||||
a[1],
|
||||
a[2],
|
||||
a[3],
|
||||
a[4],
|
||||
a[5],
|
||||
a[6],
|
||||
a[7],
|
||||
a[8],
|
||||
a[9]);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/Sequence.hpp
@@ -0,0 +1,556 @@
|
||||
#ifndef CK_SEQUENCE_HPP
|
||||
#define CK_SEQUENCE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "functional.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Seq>
|
||||
struct is_valid_sequence_map;
|
||||
|
||||
template <index_t... Is>
|
||||
struct Sequence
|
||||
{
|
||||
using Type = Sequence;
|
||||
|
||||
static constexpr index_t mSize = sizeof...(Is);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return mSize; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr index_t Get(Number<I>)
|
||||
{
|
||||
static_assert(I < mSize, "wrong! I too large");
|
||||
|
||||
// the last dummy element prevents the compiler from complaining about an empty array when mSize = 0
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t operator[](Number<I>) const
|
||||
{
|
||||
static_assert(I < mSize, "wrong! I too large");
|
||||
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
// make sure I is constexpr
|
||||
__host__ __device__ constexpr index_t operator[](index_t I) const
|
||||
{
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t... IRs>
|
||||
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs),
|
||||
"wrong! reorder map should have the same size as Sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return Sequence<Type::Get(Number<IRs>{})...>{};
|
||||
}
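// Example (illustrative): Sequence<10, 20, 30>::ReorderGivenNew2Old(Sequence<2, 0, 1>{})
// yields Sequence<30, 10, 20>: element i of the result is Get(new2old[i]) of the original.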
|
||||
|
||||
#if 0 // require sequence_sort, which is not implemented yet
|
||||
template <class MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
|
||||
{
|
||||
static_assert(sizeof...(Is) == MapOld2New::GetSize(),
|
||||
"wrong! reorder map should have the same size as Sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<MapOld2New>::value,
|
||||
"wrong! invalid reorder map");
|
||||
|
||||
constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
|
||||
|
||||
return ReorderGivenNew2Old(map_new2old);
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr auto Reverse();
|
||||
|
||||
__host__ __device__ static constexpr index_t Front()
|
||||
{
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[0];
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t Back()
|
||||
{
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[mSize - 1];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto PushFront(Number<I>)
|
||||
{
|
||||
return Sequence<I, Is...>{};
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto PushBack(Number<I>)
|
||||
{
|
||||
return Sequence<Is..., I>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto PopFront();
|
||||
|
||||
__host__ __device__ static constexpr auto PopBack();
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto Append(Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ static constexpr auto Extract(Number<Ns>...)
|
||||
{
|
||||
return Sequence<Type::Get(Number<Ns>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
|
||||
{
|
||||
return Sequence<Type::Get(Number<Ns>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>);
|
||||
};
|
||||
|
||||
// merge sequence
|
||||
template <class, class>
|
||||
struct sequence_merge;
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using SeqType = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
// arithmetic sequence
|
||||
template <index_t IBegin, index_t NSize, index_t Increment>
|
||||
struct arithmetic_sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NSizeLeft = NSize / 2;
|
||||
|
||||
using SeqType = typename sequence_merge<
|
||||
typename arithmetic_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
|
||||
typename arithmetic_sequence_gen_impl<IBegin + NSizeLeft * Increment,
|
||||
NSize - NSizeLeft,
|
||||
Increment>::SeqType>::SeqType;
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t Increment>
|
||||
struct arithmetic_sequence_gen_impl<IBegin, 1, Increment>
|
||||
{
|
||||
using SeqType = Sequence<IBegin>;
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t Increment>
|
||||
struct arithmetic_sequence_gen_impl<IBegin, 0, Increment>
|
||||
{
|
||||
using SeqType = Sequence<>;
|
||||
};
|
||||
|
||||
template <index_t IBegin, index_t IEnd, index_t Increment>
|
||||
struct arithmetic_sequence_gen
|
||||
{
|
||||
using SeqType =
|
||||
typename arithmetic_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
|
||||
};
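// Example (illustrative): arithmetic_sequence_gen<IBegin, IEnd, Increment> produces
// IEnd - IBegin terms starting at IBegin and stepping by Increment, so
// arithmetic_sequence_gen<0, 4, 1>::SeqType is Sequence<0, 1, 2, 3>.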
|
||||
|
||||
// transform sequence
|
||||
template <class, class>
|
||||
struct sequence_transform;
|
||||
|
||||
template <class F, index_t... Is>
|
||||
struct sequence_transform<F, Sequence<Is...>>
|
||||
{
|
||||
using SeqType = Sequence<F{}(Is)...>;
|
||||
};
|
||||
|
||||
// uniform sequence
|
||||
template <index_t NSize, index_t I>
|
||||
struct uniform_sequence_gen
|
||||
{
|
||||
struct return_constant
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
|
||||
};
|
||||
|
||||
using SeqType = typename sequence_transform<
|
||||
return_constant,
|
||||
typename arithmetic_sequence_gen<0, NSize, 1>::SeqType>::SeqType;
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
template <class, class, index_t>
|
||||
struct sequence_reverse_inclusive_scan;
|
||||
|
||||
template <index_t I, index_t... Is, class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
|
||||
{
|
||||
using old_scan =
|
||||
typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::SeqType;
|
||||
|
||||
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
|
||||
|
||||
using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
|
||||
};
|
||||
|
||||
template <index_t I, class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
|
||||
{
|
||||
using SeqType = Sequence<Reduce{}(I, Init)>;
|
||||
};
|
||||
|
||||
template <class Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
|
||||
{
|
||||
using SeqType = Sequence<>;
|
||||
};
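// Example (illustrative): a reverse inclusive scan of Sequence<2, 3, 4> with math::plus<index_t>
// and Init = 1 scans from the right: 4 + 1 = 5, 3 + 5 = 8, 2 + 8 = 10, giving Sequence<10, 8, 5>.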
|
||||
|
||||
// extract sequence
|
||||
template <class, class>
|
||||
struct sequence_extract;
|
||||
|
||||
template <class Seq, index_t... Is>
|
||||
struct sequence_extract<Seq, Sequence<Is...>>
|
||||
{
|
||||
using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
|
||||
};
|
||||
|
||||
// split sequence
|
||||
template <class Seq, index_t I>
|
||||
struct sequence_split
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.GetSize();
|
||||
|
||||
using range0 = typename arithmetic_sequence_gen<0, I, 1>::SeqType;
|
||||
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::SeqType;
|
||||
|
||||
using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
|
||||
using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
|
||||
};
|
||||
|
||||
// reverse sequence
|
||||
template <class Seq>
|
||||
struct sequence_reverse
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.GetSize();
|
||||
|
||||
using seq_split = sequence_split<Seq, NSize / 2>;
|
||||
using SeqType = typename sequence_merge<
|
||||
typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
|
||||
typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct sequence_reverse<Sequence<I>>
|
||||
{
|
||||
using SeqType = Sequence<I>;
|
||||
};
|
||||
|
||||
template <index_t I0, index_t I1>
|
||||
struct sequence_reverse<Sequence<I0, I1>>
|
||||
{
|
||||
using SeqType = Sequence<I1, I0>;
|
||||
};
|
||||
|
||||
#if 0 // not fully implemented
|
||||
template <class KeySeq0, class ValSeq0, class KeySeq1, class ValSeq1>
|
||||
struct sequence_sort_merge_impl;
|
||||
|
||||
template <index_t Key0,
|
||||
index_t... Keys0,
|
||||
index_t Val0,
|
||||
index_t... Vals0,
|
||||
index_t Key1,
|
||||
index_t... Keys1,
|
||||
index_t Val1,
|
||||
index_t... Vals1>
|
||||
struct sequence_sort_merge_impl<Sequence<Key0, Keys0...>,
|
||||
Sequence<Val0, Vals0...>,
|
||||
Sequence<Key1, Keys1...>,
|
||||
Sequence<Val1, Vals1...>>
|
||||
{
|
||||
};
|
||||
|
||||
template <class>
|
||||
struct sequence_sort;
|
||||
|
||||
template <index_t... Is>
|
||||
struct sequence_sort<Sequence<Is...>>
|
||||
{
|
||||
using OriginalSeqType = Sequence<Is...>;
|
||||
using SortedSeqType = xxxxx;
|
||||
using MapSorted2OriginalType = xxx;
|
||||
};
|
||||
|
||||
template <class Seq, class IsValidSeqMap>
|
||||
struct sequence_map_inverse_impl;
|
||||
|
||||
// impl for valid map, no impl for invalid map
|
||||
template <index_t... Is>
|
||||
struct sequence_map_inverse_impl<Sequence<Is...>, true>
|
||||
{
|
||||
using SeqMapType = sequence_sort<Sequence<Is...>>::MapSorted2OriginalType;
|
||||
};
|
||||
|
||||
template <class>
|
||||
struct sequence_map_inverse;
|
||||
|
||||
template <index_t... Is>
|
||||
struct sequence_map_inverse<Sequence<Is...>>
|
||||
{
|
||||
// TODO: make sure the map to be inversed is valid: [0, sizeof...(Is))
|
||||
static constexpr bool is_valid_sequence_map =
|
||||
is_same<typename sequence_sort<Sequence<Is...>>::SortedSeqType,
|
||||
typename arithmetic_sequence_gen<0, sizeof...(Is), 1>::SeqType>::value;
|
||||
|
||||
// make the compiler fail if is_valid_map != true
|
||||
using SeqMapType =
|
||||
typename sequence_map_inverse_impl<Sequence<Is...>, is_valid_map>::SeqMapType;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template <class Seq>
|
||||
struct is_valid_sequence_map
|
||||
{
|
||||
static constexpr bool value =
|
||||
#if 0 // sequence_sort is not implemented yet
|
||||
is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
|
||||
typename sequence_sort<Seq>::SortedSeqType>::value;
|
||||
#else
|
||||
true;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs + Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Xs...> seq_x, Sequence<Ys...> seq_y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs - Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs * Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs / Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs % Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs + Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs - Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs * Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs / Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs % Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y + Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
constexpr auto seq_x = Sequence<Xs...>{};
|
||||
|
||||
return Sequence<(Y - Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y * Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y / Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y % Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t I, index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
|
||||
{
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
template <class Seq>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Seq)
|
||||
{
|
||||
static_assert(Seq{}.GetSize() > 0, "wrong! cannot pop an empty Sequence!");
|
||||
return sequence_pop_front(Seq{}.Reverse()).Reverse();
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<f(Xs)...>{};
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
|
||||
|
||||
return Sequence<f(Xs, Ys)...>{};
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
|
||||
{
|
||||
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
|
||||
Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
|
||||
"Dim not the same");
|
||||
|
||||
return Sequence<f(Xs, Ys, Zs)...>{};
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::SeqType{};
|
||||
}
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::PopFront()
|
||||
{
|
||||
return sequence_pop_front(Type{});
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::PopBack()
|
||||
{
|
||||
return sequence_pop_back(Type{});
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::Reverse()
|
||||
{
|
||||
return typename sequence_reverse<Sequence<Is...>>::SeqType{};
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ constexpr auto Sequence<Is...>::Modify(Number<I>, Number<X>)
|
||||
{
|
||||
static_assert(I < GetSize(), "wrong!");
|
||||
|
||||
using seq_split = sequence_split<Type, I>;
|
||||
constexpr auto seq_left = typename seq_split::SeqType0{};
|
||||
constexpr auto seq_right = typename seq_split::SeqType1{}.PopFront();
|
||||
|
||||
return seq_left.PushBack(Number<X>{}).Append(seq_right);
|
||||
}
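// Example (illustrative): Sequence<4, 5, 6, 7>::Modify(Number<2>{}, Number<9>{}) splits the
// sequence at index 2 and rebuilds it as Sequence<4, 5, 9, 7>, replacing only that element.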
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ void print_Sequence(const char* s, Sequence<Xs...>)
|
||||
{
|
||||
constexpr index_t nsize = Sequence<Xs...>::GetSize();
|
||||
|
||||
static_assert(nsize <= 10, "wrong!");
|
||||
|
||||
static_if<nsize == 0>{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 1>{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 2>{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 3>{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 4>{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 5>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 6>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 7>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 8>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 9>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
|
||||
static_if<nsize == 10>{}(
|
||||
[&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); });
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/amd_inline_asm.hpp
@@ -0,0 +1,768 @@
|
||||
#ifndef CK_AMD_INLINE_ASM_HPP
|
||||
#define CK_AMD_INLINE_ASM_HPP
|
||||
|
||||
#include "vector_type.hpp"
|
||||
|
||||
#define NO_VM_WAIT 0
|
||||
#define NO_LGKM_WAIT 0
|
||||
#define NO_DS_READ 0
|
||||
#define NO_DS_WRITE 0
|
||||
#define NO_GLB_READ 0
|
||||
|
||||
namespace ck {
|
||||
|
||||
// cast a generic pointer to an LDS (address space 3) pointer
|
||||
extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
|
||||
|
||||
__device__ void vmcnt(index_t cnt)
|
||||
{
|
||||
#if !NO_VM_WAIT
|
||||
if(cnt == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(1) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 2)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 4)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(4) \n \
|
||||
" ::);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void lgkmcnt(index_t cnt)
|
||||
{
|
||||
#if !NO_LGKM_WAIT
|
||||
if(cnt == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(1) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 2)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(3) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 4)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(4) \n \
|
||||
" ::);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
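// Illustrative usage sketch (editor's example, not part of the original header): the two
// wrappers above emit s_waitcnt so that at most `cnt` vector-memory (vmcnt) or LDS/constant
// (lgkmcnt) operations remain outstanding. A typical inner loop might look like:
//   ds_read_b128(a_vec, p_a_lds, 0);    // issue LDS reads
//   ds_read_b128(b_vec, p_b_lds, 64);
//   lgkmcnt(0);                         // block until both reads have completed
// where a_vec, b_vec, p_a_lds and p_b_lds are hypothetical names.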
|
||||
|
||||
__device__ void outerProduct1x4(const float* a, const float* b, float* c)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
v_mac_f32 %1, %4, %6 \n \
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
|
||||
: "v"(a[0]),
|
||||
"v"(b[0]),
|
||||
"v"(b[1]),
|
||||
"v"(b[2]),
|
||||
"v"(b[3]),
|
||||
"0"(c[0]),
|
||||
"1"(c[1]),
|
||||
"2"(c[2]),
|
||||
"3"(c[3]));
|
||||
}
|
||||
|
||||
__device__ void outerProduct1x4(const float& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c)
|
||||
{
|
||||
#if 0
|
||||
asm volatile(
|
||||
"\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
v_mac_f32 %1, %4, %6 \n \
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
:
|
||||
:"v"(c.x),"v"(c.y),"v"(c.z),"v"(c.w), \
|
||||
"v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
|
||||
);
|
||||
#else
|
||||
outerProduct1x4(&a, (float*)&b, (float*)&c);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c0,
|
||||
vector_type<float, 4>::MemoryType& c1,
|
||||
vector_type<float, 4>::MemoryType& c2,
|
||||
vector_type<float, 4>::MemoryType& c3)
|
||||
{
|
||||
#if 0
|
||||
asm volatile(
|
||||
"\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
v_mac_f32 %1, %4, %6 \n \
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
:
|
||||
:"v"(c0.x),"v"(c0.y),"v"(c0.z),"v"(c0.w), \
|
||||
"v"(a.x),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
|
||||
);
|
||||
asm volatile(
|
||||
"\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
v_mac_f32 %1, %4, %6 \n \
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
:
|
||||
:"v"(c1.x),"v"(c1.y),"v"(c1.z),"v"(c1.w), \
|
||||
"v"(a.y),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
|
||||
);
|
||||
asm volatile(
|
||||
"\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
v_mac_f32 %1, %4, %6 \n \
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
:
|
||||
:"v"(c2.x),"v"(c2.y),"v"(c2.z),"v"(c2.w), \
|
||||
"v"(a.z),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
|
||||
);
|
||||
asm volatile(
|
||||
"\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
v_mac_f32 %1, %4, %6 \n \
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
:
|
||||
:"v"(c3.x),"v"(c3.y),"v"(c3.z),"v"(c3.w), \
|
||||
"v"(a.w),"v"(b.x),"v"(b.y),"v"(b.z),"v"(b.w)
|
||||
);
|
||||
#else
|
||||
outerProduct1x4(a.x, b, c0);
|
||||
outerProduct1x4(a.y, b, c1);
|
||||
outerProduct1x4(a.z, b, c2);
|
||||
outerProduct1x4(a.w, b, c3);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
|
||||
const vector_type<float, 4>::MemoryType* b,
|
||||
vector_type<float, 4>::MemoryType* c)
|
||||
{
|
||||
outerProduct4x4(a[0], b[0], c[0], c[2], c[4], c[6]);
|
||||
outerProduct4x4(a[0], b[1], c[1], c[3], c[5], c[7]);
|
||||
outerProduct4x4(a[1], b[0], c[8], c[10], c[12], c[14]);
|
||||
outerProduct4x4(a[1], b[1], c[9], c[11], c[13], c[15]);
|
||||
}
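// Layout note (editor's reading, not stated in the original source): outerProduct8x8 takes
// a[0..1] and b[0..1] as two float4 halves of 8-element vectors and accumulates the 8x8
// outer product into 16 float4 accumulators, with c[2*row + half] holding the 1x4 slice of
// `row` (0..7) against b[half] (half = 0 or 1).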
|
||||
|
||||
__device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
|
||||
{
|
||||
#if !NO_DS_READ
|
||||
if(offset == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:0\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 64)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:64\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 128)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:128\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 192)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:192\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 256)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:256\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 320)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:320\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 384)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:384\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 448)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:448\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 512)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:512\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 576)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:576\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 640)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:640\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 704)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:704\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 768)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:768\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 832)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:832\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 896)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:896\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 960)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:960\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1024)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1024\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1088)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1088\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1152)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1152\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1216)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1216\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1280)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1280\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1344)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1344\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1408)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1408\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1472)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1472\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1536)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1536\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1600)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1600\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1664)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1664\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1728)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1728\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1792)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1792\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1856)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1856\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1920)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1920\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1984)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1984\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2048)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2048\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2112)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2112\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2176)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2176\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2240)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2240\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2304)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2304\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2368)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2368\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2432)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2432\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2496)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2496\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2560)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2560\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2624)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2624\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2688)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2688\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2752)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2752\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2816)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2816\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2880)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2880\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2944)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2944\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3008)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3008\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3072)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3072\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3136)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3136\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3200)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3200\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3264)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3264\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3328)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3328\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3392)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3392\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3456)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3456\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3520)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3520\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3584)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3584\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3648)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3648\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3712)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3712\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3776)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3776\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3840)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3840\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3904)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3904\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3968)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3968\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4032)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4032\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4096)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4096\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void global_load(vector_type<float, 4>::MemoryType& r,
|
||||
const vector_type<float, 4>::MemoryType* ptr,
|
||||
index_t offset = 0)
|
||||
{
|
||||
#if !NO_GLB_READ
|
||||
if(offset == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
global_load_dwordx4 %0, %1, off \n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(ptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void
|
||||
ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
|
||||
{
|
||||
#if !NO_DS_WRITE
|
||||
if(offset == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_write_b128 %0, %1 \n \
|
||||
"
|
||||
:
|
||||
: "v"(__to_local(lds)), "v"(r));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/common_header.hpp (new file, 18 lines)
@@ -0,0 +1,18 @@
#ifndef CK_COMMON_HPP
#define CK_COMMON_HPP

#include "config.hpp"
#include "utility.hpp"
#include "vector_type.hpp"
#include "integral_constant.hpp"
#include "Sequence.hpp"
#include "Array.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"

#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif

#endif
composable_kernel/include/utility/config_amd.hpp.in (new file, 41 lines)
@@ -0,0 +1,41 @@
|
||||
#ifndef CK_CONFIG_AMD_HPP
|
||||
#define CK_CONFIG_AMD_HPP
|
||||
|
||||
#cmakedefine01 CK_DEVICE_BACKEND_AMD
|
||||
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
|
||||
namespace ck {
|
||||
|
||||
// For some reason, the HIP compiler needs this definition to generate optimal load and
// store instructions
|
||||
typedef float float2_t __attribute__((ext_vector_type(2)));
|
||||
typedef float float4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
using index_t = uint32_t;
|
||||
|
||||
__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
|
||||
{
|
||||
d += s0 * s1;
|
||||
}
|
||||
|
||||
#if 0
|
||||
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }
|
||||
|
||||
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x;
|
||||
d += s0.y * s1.y;
|
||||
}
|
||||
|
||||
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x + s0.y * s1.y;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
composable_kernel/include/utility/config_nvidia.hpp.in (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
#ifndef CK_CONFIG_NVIDIA_HPP
|
||||
#define CK_CONFIG_NVIDIA_HPP
|
||||
|
||||
#cmakedefine01 CK_DEVICE_BACKEND_NVIDIA
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "cuda_fp16.h"
|
||||
#include "nvToolsExt.h"
|
||||
#include "helper_cuda.h"
|
||||
#define CK_USE_AMD_INLINE_ASM 0
|
||||
|
||||
namespace ck {
|
||||
|
||||
// For some reason, CUDA needs this definition; otherwise the compiler won't generate
// optimal load and store instructions, and the kernel would produce wrong results,
// indicating that the compiler failed to generate correct instructions
|
||||
using float2_t = float2;
|
||||
using float4_t = float4;
|
||||
|
||||
using index_t = uint32_t;
|
||||
|
||||
__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
|
||||
{
|
||||
d += s0 * s1;
|
||||
}
|
||||
|
||||
#if 0
|
||||
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1) { d += s0 * s1; }
|
||||
|
||||
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x;
|
||||
d += s0.y * s1.y;
|
||||
}
|
||||
|
||||
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x + s0.y * s1.y;
|
||||
}
|
||||
|
||||
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1) { d += s0 * s1; }
|
||||
|
||||
// TODO:: this interface is misleading, s0, s1 are actually int8x4
|
||||
// need to make a better interface
|
||||
__device__ void fused_multiply_accumulate(int32_t& d, const int32_t& s0, const int32_t& s1)
|
||||
{
|
||||
#if CK_DEVICE_BACKEND_NVIDIA
|
||||
d = __dp4a(s0, s1, d);
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
composable_kernel/include/utility/functional.hpp (new file, 79 lines)
@@ -0,0 +1,79 @@
|
||||
#ifndef CK_FUNCTIONAL_HPP
|
||||
#define CK_FUNCTIONAL_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "Sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
struct forwarder
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T&& operator()(T&& x) const
|
||||
{
|
||||
return static_cast<T&&>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <class... Ts>
|
||||
__host__ __device__ constexpr swallow(Ts&&... ts)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// Emulate if constexpr
|
||||
template <bool Predicate>
|
||||
struct static_if
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_if<true>
|
||||
{
|
||||
using Type = static_if<true>;
|
||||
|
||||
template <class F>
|
||||
__host__ __device__ constexpr auto operator()(F f) const
|
||||
{
|
||||
// This is a trick for compiler:
|
||||
// Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will use it,
|
||||
// this will make "f" a generic lambda, so that "f" won't be compiled until being
|
||||
// instantiated here
|
||||
f(forwarder{});
|
||||
return Type{};
|
||||
}
|
||||
|
||||
template <class F>
|
||||
__host__ __device__ static constexpr auto Else(F)
|
||||
{
|
||||
return Type{};
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_if<false>
|
||||
{
|
||||
using Type = static_if<false>;
|
||||
|
||||
template <class F>
|
||||
__host__ __device__ constexpr auto operator()(F) const
|
||||
{
|
||||
return Type{};
|
||||
}
|
||||
|
||||
template <class F>
|
||||
__host__ __device__ static constexpr auto Else(F f)
|
||||
{
|
||||
// This is a trick for compiler:
|
||||
// Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will use it,
|
||||
// this will make "f" a generic lambda, so that "f" won't be compiled until being
|
||||
// instantiated here
|
||||
f(forwarder{});
|
||||
return Type{};
|
||||
}
|
||||
};
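// Illustrative usage (editor's example, not part of the original header); kUseInlineAsm,
// run_asm_path and run_generic_path are hypothetical:
//   static_if<kUseInlineAsm>{}([&](auto fwd) { fwd(run_asm_path)(); })
//       .Else([&](auto fwd) { fwd(run_generic_path)(); });
// Routing the call through the `auto` forwarder keeps the untaken branch dependent, so it is
// never instantiated (the "trick" described in the comments above).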
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/functional2.hpp (new file, 68 lines)
@@ -0,0 +1,68 @@
|
||||
#ifndef CK_FUNCTIONAL2_HPP
|
||||
#define CK_FUNCTIONAL2_HPP
|
||||
|
||||
#include "functional.hpp"
|
||||
#include "Sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct static_for_impl<Sequence<Is...>>
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
swallow{(f(Number<Is>{}), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
// F signature: F(Number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
static_assert(NBegin <= NEnd, "wrong! should have NBegin <= NEnd");
|
||||
|
||||
static_assert((NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
|
||||
static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::SeqType>{}(f);
|
||||
}
|
||||
};
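// Illustrative usage (editor's example, not part of the original header): unroll
// I = 0, 2, 4, 6 at compile time; `buf` is a hypothetical array.
//   static_for<0, 8, 2>{}([&](auto I) { buf[I.Get()] = 0; });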
|
||||
|
||||
template <class Seq, class Reduce>
|
||||
struct lambda_accumulate_on_sequence
|
||||
{
|
||||
const Reduce& f;
|
||||
index_t& result;
|
||||
|
||||
__host__ __device__ constexpr lambda_accumulate_on_sequence(const Reduce& f_, index_t& result_)
|
||||
: f(f_), result(result_)
|
||||
{
|
||||
}
|
||||
|
||||
template <class IDim>
|
||||
__host__ __device__ constexpr index_t operator()(IDim) const
|
||||
{
|
||||
return result = f(result, Seq::Get(IDim{}));
|
||||
}
|
||||
};
|
||||
|
||||
template <class Seq, class Reduce, index_t Init>
|
||||
__host__ __device__ constexpr index_t
|
||||
accumulate_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
|
||||
{
|
||||
index_t result = Init;
|
||||
|
||||
static_for<0, Seq::mSize, 1>{}(lambda_accumulate_on_sequence<Seq, Reduce>(f, result));
|
||||
|
||||
return result;
|
||||
}
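// Illustrative usage (editor's example, not part of the original header), assuming
// math::plus from utility.hpp is visible through common_header.hpp:
//   index_t sum = accumulate_on_sequence(Sequence<2, 3, 4>{}, math::plus<index_t>{}, Number<0>{});
//   // sum == 9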
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/functional3.hpp (new file, 116 lines)
@@ -0,0 +1,116 @@
|
||||
#ifndef CK_FUNCTIONAL3_HPP
|
||||
#define CK_FUNCTIONAL3_HPP
|
||||
|
||||
#include "functional.hpp"
|
||||
#include "functional2.hpp"
|
||||
#include "Sequence.hpp"
|
||||
#include "Array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class RemainLengths>
|
||||
struct static_ford_impl
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
|
||||
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
|
||||
CurrentMultiIndex::PushBack(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_ford_impl<Sequence<>>
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ constexpr void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
f(CurrentMultiIndex{});
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>
|
||||
template <class Lengths>
|
||||
struct static_ford
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
|
||||
static_ford_impl<Lengths>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t RemainDim>
|
||||
struct ford_impl
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
|
||||
static_assert(RemainDim > 1, "wrong!");
|
||||
|
||||
constexpr auto next_length = RemainLengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < next_length; ++i)
|
||||
{
|
||||
ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ford_impl<1>
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentMultiIndex: Array<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == 1, "wrong!");
|
||||
|
||||
constexpr index_t last_length = RemainLengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < last_length; ++i)
|
||||
{
|
||||
f(current_multi_id.PushBack(i));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>
|
||||
template <class Lengths>
|
||||
struct ford
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr index_t first_length = Lengths{}.Front();
|
||||
|
||||
for(index_t i = 0; i < first_length; ++i)
|
||||
{
|
||||
ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
|
||||
}
|
||||
}
|
||||
};
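// Illustrative usage (editor's example, not part of the original header): visit every index
// of a 2 x 3 iteration space.
//   static_ford<Sequence<2, 3>>{}([&](auto multi_id) { /* multi_id is a Sequence<i, j> */ });
//   ford<Sequence<2, 3>>{}([&](auto multi_id) { /* multi_id is an Array<index_t, 2> */ });
// static_ford fully unrolls with compile-time indices; ford uses runtime loops.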
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/integral_constant.hpp (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
#ifndef CK_INTEGRAL_CONSTANT_HPP
|
||||
#define CK_INTEGRAL_CONSTANT_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class T, T N>
|
||||
struct integral_constant
|
||||
{
|
||||
static const T value = N;
|
||||
|
||||
__host__ __device__ constexpr T Get() const { return value; }
|
||||
};
|
||||
|
||||
template <class T, T X, T Y>
|
||||
__host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
|
||||
{
|
||||
return integral_constant<T, X + Y>{};
|
||||
}
|
||||
|
||||
template <index_t N>
|
||||
using Number = integral_constant<index_t, N>;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
composable_kernel/include/utility/utility.hpp (new file, 122 lines)
@@ -0,0 +1,122 @@
|
||||
#ifndef CK_UTILITY_HPP
|
||||
#define CK_UTILITY_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
__device__ index_t get_block_1d_id() { return blockIdx.x; }
|
||||
|
||||
template <class T1, class T2>
|
||||
struct is_same
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct is_same<T, T>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <class X, class Y>
|
||||
__host__ __device__ constexpr bool is_same_type(X, Y)
|
||||
{
|
||||
return is_same<X, Y>::value;
|
||||
}
|
||||
|
||||
namespace math {
|
||||
|
||||
template <class T, T s>
|
||||
struct scales
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct plus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct minus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct multiplies
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct integer_divide_ceiler
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const
|
||||
{
|
||||
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
|
||||
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
|
||||
{
|
||||
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
|
||||
|
||||
return (a + b - 1) / b;
|
||||
}
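// Worked example (editor's note, not part of the original header): integer_divide_ceil(10, 4) == 3,
// i.e. the number of size-4 tiles needed to cover 10 elements.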
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T max(T x, T y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <class T, class... Ts>
|
||||
__host__ __device__ constexpr T max(T x, Ts... xs)
|
||||
{
|
||||
static_assert(sizeof...(xs) > 0, "not enough arguments");
|
||||
|
||||
auto y = max(xs...);
|
||||
|
||||
static_assert(is_same<decltype(y), T>::value, "not the same type");
|
||||
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T min(T x, T y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <class T, class... Ts>
|
||||
__host__ __device__ constexpr T min(T x, Ts... xs)
|
||||
{
|
||||
static_assert(sizeof...(xs) > 0, "not enough arguments");
|
||||
|
||||
auto y = min(xs...);
|
||||
|
||||
static_assert(is_same<decltype(y), T>::value, "not the same type");
|
||||
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
// this is wrong
|
||||
// TODO: implement least common multiple properly, instead of calling max()
|
||||
template <class T, class... Ts>
|
||||
__host__ __device__ constexpr T lcm(T x, Ts... xs)
|
||||
{
|
||||
return max(x, xs...);
|
||||
}
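// A minimal sketch of a proper least common multiple, per the TODO above (editor's
// illustration, not part of the original header); gcd_ and lcm2_ are hypothetical names:
//   __host__ __device__ constexpr index_t gcd_(index_t a, index_t b)
//   {
//       return b == 0 ? a : gcd_(b, a % b);
//   }
//   __host__ __device__ constexpr index_t lcm2_(index_t a, index_t b)
//   {
//       return (a / gcd_(a, b)) * b;
//   }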
|
||||
|
||||
} // namespace math
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
composable_kernel/include/utility/vector_type.hpp (new file, 194 lines)
@@ -0,0 +1,194 @@
|
||||
#ifndef CK_VECTOR_TYPE_HPP
|
||||
#define CK_VECTOR_TYPE_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class T, index_t N>
|
||||
struct vector_type
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 1>
|
||||
{
|
||||
typedef float MemoryType;
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
|
||||
{
|
||||
static_assert(I < 1, "wrong");
|
||||
*(reinterpret_cast<float*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 2>
|
||||
{
|
||||
using MemoryType = float2_t;
|
||||
|
||||
union Data
|
||||
{
|
||||
MemoryType vector;
|
||||
float scalar[2];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
|
||||
{
|
||||
static_assert(I < 2, "wrong");
|
||||
*(reinterpret_cast<float*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(float s0, float s1)
|
||||
{
|
||||
Data data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
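// Illustrative usage (editor's example, not part of the original header):
//   auto v = vector_type<float, 2>::Pack(1.0f, 2.0f);          // v == {1.0f, 2.0f}
//   vector_type<float, 2>::SetScalar(v, 3.0f, Number<1>{});    // v == {1.0f, 3.0f}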
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 4>
|
||||
{
|
||||
using MemoryType = float4_t;
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
|
||||
{
|
||||
static_assert(I < 4, "wrong");
|
||||
*(reinterpret_cast<float*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
struct vector_type<half, 1>
|
||||
{
|
||||
using MemoryType = half;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(half s) { return s; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 2>
|
||||
{
|
||||
using MemoryType = half2;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(half s0, half s1)
|
||||
{
|
||||
union
|
||||
{
|
||||
MemoryType vector;
|
||||
half scalar[2];
|
||||
} data;
|
||||
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 4>
|
||||
{
|
||||
using MemoryType = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 8>
|
||||
{
|
||||
using MemoryType = float4;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char, 1>
|
||||
{
|
||||
using MemoryType = char;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s) { return s; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char, 2>
|
||||
{
|
||||
using MemoryType = int16_t;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s0, char s1)
|
||||
{
|
||||
union
|
||||
{
|
||||
MemoryType vector;
|
||||
char scalar[2];
|
||||
} data;
|
||||
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char, 4>
|
||||
{
|
||||
using MemoryType = int32_t;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3)
|
||||
{
|
||||
union
|
||||
{
|
||||
MemoryType vector;
|
||||
char scalar[4];
|
||||
} data;
|
||||
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
data.scalar[2] = s2;
|
||||
data.scalar[3] = s3;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char, 8>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<int32_t, 2>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char2, 2>
|
||||
{
|
||||
using MemoryType = char4;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char2, 4>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char4, 1>
|
||||
{
|
||||
using MemoryType = int;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char4, 2>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
#endif
|
||||
} // namespace ck

#endif