mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
adding padding to implicit gemm v1r3
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
@@ -79,21 +79,21 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
|
||||
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
|
||||
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
|
||||
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
|
||||
const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
@@ -74,14 +74,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// assert for LDS double buffer
|
||||
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % (2 * CPerBlock) == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
|
||||
@@ -0,0 +1,420 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class LowerPads,
|
||||
class UpperPads,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_CHWN,
|
||||
class InBlockCopyClusterLengths_CHWN,
|
||||
index_t InBlockCopyDataPerAccess_N,
|
||||
class WeiBlockCopySubLengths_CK,
|
||||
class WeiBlockCopyClusterLengths_CK,
|
||||
index_t WeiBlockCopyDataPerAccess_K,
|
||||
index_t OutThreadCopyDataPerAccess_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock);
|
||||
constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
|
||||
const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with alignment
|
||||
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not meet");
|
||||
|
||||
constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
3,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
|
||||
{0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace();
|
||||
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register
|
||||
// C++ lambda doesn't capture array, use pointer instead
|
||||
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
Float* const p_out_thread = p_out_thread_data;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread);
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global +
|
||||
wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
|
||||
// fwd do nothing but perfect forwarding.
|
||||
// Using this trick to make this lambda a generic lambda, so it won't be compiled until
|
||||
// being instantiated here
|
||||
static_assert(
|
||||
(fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<1>{}, Number<N2>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"a: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"a: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
Float* p_out_thread_on_global = p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
}).Else([&](auto fwd) {
|
||||
static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
fwd(out_k_h_w_n_global_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
|
||||
.Fold(I0, Number<K1>{}, Number<K2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc =
|
||||
fwd(out_k_h_w_n_thread_desc)
|
||||
.Fold(I3, Number<N1>{})
|
||||
.Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
|
||||
.Fold(I0, Number<1>{}, Number<K2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
|
||||
"b: out_k_h_w_n_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
|
||||
"b: out_k_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
Float* p_out_thread_on_global = p_out_global +
|
||||
out_k_h_w_n_global_desc.GetOffsetFromMultiIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -301,14 +301,14 @@ struct TensorCoordinate
|
||||
private:
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate<ConstantTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate<ConstantMergedTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
@@ -472,55 +472,54 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
#endif
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
|
||||
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
|
||||
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_chwn_desc),
|
||||
decltype(wei_cyxk_desc),
|
||||
decltype(out_khwn_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
OutThreadCopyDataPerAccess_N>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_chwn_desc),
|
||||
decltype(wei_cyxk_desc),
|
||||
decltype(out_khwn_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
OutThreadCopyDataPerAccess_N>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc, class LowerPads, class UpperPads>
|
||||
void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
LowerPads,
|
||||
UpperPads,
|
||||
index_t nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcyx_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
|
||||
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
|
||||
|
||||
// reorder weight
|
||||
auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Y, X, K>{});
|
||||
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
|
||||
|
||||
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
|
||||
|
||||
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
|
||||
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// reorder input
|
||||
auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence<C, Hi, Wi, N>{});
|
||||
ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: ");
|
||||
|
||||
Tensor<T> in_chwn(make_TensorDescriptor(in_chwn_desc));
|
||||
|
||||
auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) {
|
||||
in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// output
|
||||
auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence<K, Ho, Wo, N>{});
|
||||
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
|
||||
|
||||
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace());
|
||||
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
|
||||
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
|
||||
|
||||
in_chwn_device_buf.ToDevice(in_chwn.mData.data());
|
||||
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
|
||||
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
|
||||
|
||||
#if 1
|
||||
// v1r3, 3x3, 32x32, 1x1 pad
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
|
||||
constexpr index_t InBlockCopyDataPerAccess_N = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_CK = Sequence<2, 4>;
|
||||
using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>;
|
||||
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerAccess_N = 2;
|
||||
#endif
|
||||
|
||||
constexpr index_t GridSize =
|
||||
(N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_chwn_desc),
|
||||
decltype(wei_cyxk_desc),
|
||||
decltype(out_khwn_desc),
|
||||
LowerPads,
|
||||
UpperPads,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
OutThreadCopyDataPerAccess_N>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
|
||||
|
||||
// reorder output
|
||||
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
|
||||
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "host_conv.hpp"
|
||||
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
|
||||
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
|
||||
//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
|
||||
//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
|
||||
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
|
||||
@@ -71,7 +72,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 1536;
|
||||
constexpr index_t HI = 8;
|
||||
@@ -367,9 +368,19 @@ int main(int argc, char* argv[])
|
||||
#if 0
|
||||
device_convolution_direct_v2_nchw_kcyx_nkhw
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
|
||||
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
lower_pads,
|
||||
upper_pads,
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(
|
||||
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
@@ -419,16 +430,6 @@ int main(int argc, char* argv[])
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
nrepeat);
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
lower_pads,
|
||||
upper_pads,
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
|
||||
Reference in New Issue
Block a user