GEMM pipeline v2 (#317)

* format

* improving pipeline

* fix typo

* format

* adding thread group

* adding thread group

* adding thread group

* adding gemm pipeline

* tweak

* refactor

* refactor

* add missing type convert

* refactor

* refactor

* refactor

* clean

* fix build

* refactor

* format

* clean up

* use remove_cvref_t

* clean

* use pipeline_v2 for gemm kernel

* Remove inconsistent indent

* Fix compilation errors due to incomplete merge process

* Add missing include directives

* Fix compilation errors in currently unused files

* Add license in newly added files

* Re-format touched files by clang-format-10

* Fix wrong template argument count of DeviceGemm<>

* Use language construct to choose between types

* Use language construct to choose GEMM example instance

* Fix compilation error due to interface change

* Re-use type alias to avoid duplication

* Unify type alias usage in source file

* Only use v2 pipeline in one gridwise GEMM type

* Remove no-longer used include directives

* Add static_assert() to check pipeline type requirements

* Revert "Add static_assert() to check pipeline type requirements"

This reverts commit f0985f0a13.

* clean

* clean

* clean

* clean

Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: shaojiewang <wsjmessi@163.com>
This commit is contained in:
Po Yen Chen
2022-07-09 04:55:14 +08:00
committed by GitHub
parent 763ca61581
commit 639147432b
5 changed files with 160 additions and 22 deletions

View File

@@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
namespace ck {
struct GridwiseGemmPipeline_v2
{
__host__ __device__ static constexpr bool IsSupported(index_t num_loop)
{
// TODO: improve applicability
return num_loop % 2 == 0;
}
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return (num_loop / 2) > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
// global read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// move to 1
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
// LDS write 0
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// global Read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write 0
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// global Read 1
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
// GEMM i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// move to i + 2
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// LDS write i + 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// global read i + 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
// LDS write i + 1
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// global read i + 2
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
++i;
} while(i < (num_loop - 2));
}
// tail
{
block_sync_lds();
// GEMM num_loop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
// LDS write num_loop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
block_sync_lds();
// GEMM num_loop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
};
} // namespace ck

View File

@@ -9,6 +9,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
@@ -134,7 +135,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using GridwiseGemmPipe =
#if 1
remove_cvref_t<decltype(
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>())>;
#else
GridwiseGemmPipeline_v2;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
@@ -425,8 +433,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
// gridwise GEMM pipeline
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /