Fused GEMM+GEMM (#351)

* initial stub for gemm_gemm_xdl_cshuffle

* set up example code

* compiles

* prevent integer overflow

* harmonize interface between ref_gemm and ref_batched_gemm

* batched_gemm_gemm

* fix example

* host tensor gen: diagonal pattern in lowest two-dimensions only

* make c descriptors containing only integral constants

* clean up

* add BlockwiseGemmXdlops_v2 while exploring an unified approach

* implement proper interface

* tidy up example

* fix compilation warnings

* coarsely controlled 2nd gemm padding

* remove rocm-cmake's hard requirement for certain revision

* clang-format

* resolve merge conflict

* fix compilation error on gfx10

* adds acc0 elementwise op to interface

* add gemm_gemm instances and tests

* avoid LDS data hazard

* fix build

Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
Anthony Chang
2022-08-13 22:18:58 +08:00
committed by GitHub
parent 10b3278b05
commit c20a75b07d
32 changed files with 2822 additions and 689 deletions

View File

@@ -0,0 +1,93 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Row,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ALayout,
typename B0Layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceBatchedGemmGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>
{
using DeviceOp = DeviceBatchedGemmGemm<ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
{
add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -13,6 +13,7 @@ add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(contraction_scale)

View File

@@ -0,0 +1,8 @@
set(DEVICE_BATCHED_GEMM_GEMM_INSTANCE_SOURCE
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)
add_instance_library(device_batched_gemm_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_GEMM_INSTANCE_SOURCE})
target_compile_features(device_batched_gemm_gemm_instance PUBLIC)
set_target_properties(device_batched_gemm_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(device_batched_gemm_gemm_instance)

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
Row,
Row,
F16,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck