Files
composable_kernel/composable_kernel/include/tensor_operation/gridwise_gemm.hpp
Chao Liu 4414e495ed backward data (#7)
* enabled atomic add in tensor copy
* added gridwise GEMM
* added backward data conv using GEMM + atomic
* added backward data conv using GEMM, no atomic


[ROCm/composable_kernel commit: 8f5f64960e]
2019-12-03 01:16:12 -06:00

326 lines
15 KiB
C++

#ifndef CK_GRIDWISE_GEMM_HPP
#define CK_GRIDWISE_GEMM_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace ck {
// Grid-level GEMM. Each workgroup accumulates transpose(A) * B for one
// [MPerBlock, NPerBlock] tile of C in a per-thread array, stepping through K in
// slices of KPerBlock that are staged through double-buffered LDS, and finally
// writes the tile to global C using CGlobalMemoryDataOperation.
//
// Template parameters (as used in the code below):
//   GridSize                     - not referenced inside Run.
//   BlockSize                    - forwarded to both blockwise copies and to the blockwise GEMM
//                                  (presumably the number of threads per workgroup - confirm).
//   Float                        - element type of the global A/B/C pointers and of the LDS buffers.
//   AccFloat                     - element type of the per-thread accumulator array.
//   AGlobalDesc                  - descriptor of A; its lengths are read as [K, M].
//   BGlobalDesc                  - descriptor of B; its length[1] is read as N (named [K, N]).
//   CGlobalDesc                  - descriptor of C; dim 0 is split by M1 and dim 1 by N1, i.e. it
//                                  is treated as [M, N].
//   CGlobalMemoryDataOperation   - forwarded to the final thread-to-global copy.
//   MPerBlock/NPerBlock/KPerBlock- tile sizes; M, N and K must be divisible by them (static_assert).
//   MPerThreadSubC ... ThreadGemmDataPerReadN
//                                - forwarded to the blockwise GEMM; the M/N sub-tile and cluster
//                                  sizes also define GemmMRepeat/GemmNRepeat and M1/N1 below.
//   ABlockCopy* / BBlockCopy*    - forwarded to BlockwiseGenericTensorSliceCopy_v4; the
//                                  DataPerAccess values also feed max_lds_align.
//   CThreadCopyDataPerAccess_N   - forwarded (twice) to the final threadwise copy
//                                  (presumably src/dst access width - confirm).
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename AccFloat,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
InMemoryDataOperation CGlobalMemoryDataOperation,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t MLevel0Cluster,
index_t NLevel0Cluster,
index_t MLevel1Cluster,
index_t NLevel1Cluster,
index_t KPerThreadLoop,
index_t ThreadGemmDataPerReadM,
index_t ThreadGemmDataPerReadN,
typename ABlockCopySubLengths_K_M,
typename ABlockCopyClusterLengths_K_M,
index_t ABlockCopyDataPerAccess_M,
typename BBlockCopySubLengths_K_N,
typename BBlockCopyClusterLengths_K_N,
index_t BBlockCopyDataPerAccess_N,
index_t CThreadCopyDataPerAccess_N>
struct GridwiseGemmTransposedANormalBNormalC_v1r1
{
// Device entry point: computes this workgroup's tile of C.
//   p_a_global - base pointer of A, addressed through AGlobalDesc.
//   p_b_global - base pointer of B, addressed through BGlobalDesc.
//   p_c_global - base pointer of C, addressed through CGlobalDesc; written by the
//                final threadwise copy with CGlobalMemoryDataOperation.
// Contains __syncthreads(), so every thread of the workgroup must call it.
__device__ void Run(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global) const
{
// compile-time tag passed as 2nd argument of MoveSrcSliceWindow
// (presumably selects a positive step direction - confirm against the copy class)
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
// problem sizes, taken from the A and B descriptors
constexpr auto K = a_k_m_global_desc.GetLengths()[0];
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// lds max alignment: least common multiple of every access width that touches
// LDS (block-copy writes and thread-GEMM reads); used both for the LDS
// descriptors and for rounding the buffer sizes below
constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
BBlockCopyDataPerAccess_N,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN);
// divide block work by [M, N]
// (no partial tiles are handled: M, N and K must be exact multiples of the tile sizes)
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
// map the 1D block id to a 2D (m-tile, n-tile) index
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
// origin of this block's tile along M and N
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// A matrix blockwise copy
// source window starts at {k = 0, m = m_block_data_on_global}; destination at {0, 0}
auto a_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopySubLengths_K_M,
ABlockCopyClusterLengths_K_M,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
ABlockCopyDataPerAccess_M,
ABlockCopyDataPerAccess_M,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, m_block_data_on_global}, {0, 0});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// B matrix blockwise copy
// source window starts at {k = 0, n = n_block_data_on_global}; destination at {0, 0}
auto b_blockwise_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopySubLengths_K_N,
BBlockCopyClusterLengths_K_N,
Sequence<0, 1>,
Sequence<0, 1>,
Sequence<0, 1>,
1,
1,
BBlockCopyDataPerAccess_N,
BBlockCopyDataPerAccess_N,
AddressSpace::global,
AddressSpace::vgpr,
AddressSpace::lds,
InMemoryDataOperation::none>(
{0, n_block_data_on_global}, {0, 0});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check: the block tile must be a whole number of thread-cluster tiles
static_assert(MPerBlock % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
// number of thread-cluster tiles along M / N inside one block tile
constexpr index_t GemmMRepeat =
MPerBlock / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat =
NPerBlock / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
// per-thread accumulator: packed [GemmMRepeat * MPerThreadSubC, GemmNRepeat * NPerThreadSubC]
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThreadSubC>{}, Number<GemmNRepeat * NPerThreadSubC>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThreadSubC,
NPerThreadSubC,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThreadLoop,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN>{};
// LDS allocation for A and B: be careful of alignment
// each half of a double buffer is rounded up to a multiple of max_lds_align so
// that the second half starts at an aligned element offset
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
// two halves each: [0, space) and [space, 2 * space)
__shared__ Float p_a_block_double[2 * a_block_space];
__shared__ Float p_b_block_double[2 * b_block_space];
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
// LDS double buffer: preload data into LDS
// (first K-slice goes into the first half of each double buffer)
{
a_blockwise_copy.Run(p_a_global, p_a_block_double);
b_blockwise_copy.Run(p_b_global, p_b_block_double);
}
// LDS double buffer: main body
// each outer iteration consumes two K-slices (one per buffer half), so it always
// ends with the newest slice stored in the first half; the loop condition stops
// early enough to leave either one or two slices for the tail below
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
// even step: compute on first half, fill second half; odd step: the reverse
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
// per-thread staging buffers between the global load and the LDS store
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
// advance the global source windows by one K-slice
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
// single barrier per step: the previous step's LDS stores into "now" are
// complete before the GEMM reads them, and the previous step's GEMM reads of
// "next" are complete before it is overwritten below
__syncthreads();
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
// K is a multiple of KPerBlock (static_assert above): if it is also a multiple
// of 2 * KPerBlock two slices remain, otherwise exactly one; in both cases the
// oldest unprocessed slice is already in the first half of the double buffers
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
a_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_blockwise_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
// same role as the barrier in the main loop
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
// make the just-stored last slice visible to all threads before it is read
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if has 1 iteration left
{
// wait for the LDS stores of the last slice (preload or final loop step)
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// output: copy the per-thread accumulators from registers to global memory
{
// global C is viewed as [M0, M1, N0, N1], where M1 / N1 is the extent covered
// by one pass of the thread clusters along M / N
constexpr index_t M1 = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
// define the thread-side tensor descriptor for the threadwise copy
// thread accumulator tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThreadSubC, GemmNRepeat, NPerThreadSubC>{});
// global C tensor, dst of threadwise copy: dim 0 -> (M0, M1), dim 1 -> (N0, N1)
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// calculate origin of the thread's output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
// src origin is {0, 0, 0, 0} in the thread tensor; dst origin is the thread's
// (m, n) position decomposed into (M0, M1, N0, N1) coordinates
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
Sequence<0, 1, 2, 3>,
3,
CThreadCopyDataPerAccess_N,
CThreadCopyDataPerAccess_N,
AddressSpace::vgpr,
AddressSpace::global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1})
.Run(p_c_thread, p_c_global);
}
}
};
} // namespace ck
#endif