Grouped Gemm + SplitK + simplified Kernel Args (#669)

* simplify karg in device/grid split-k op

* fix mk_kn_mn instances

* add more instances

* B2C with 3D grid for KSplit

* Remove unused code.

* Use default B2C (3D grid) in grid gemm v2r4r2.

* Device gemm splitk use B2C map.

* Device GroupedGemmXdlSplitKCShuffle

* Example for GroupedGemm Xdl SplitK

* Introduce Device GroupedGemmSplitK

* Fix updating kbatch size.

* Add instance mk-nk-mn

* Enable set kbatch in profiler.

* Add GGemmSplitK mk-kn-mn instances

* Add more instances & split into multiple files.

* minor fix

* tuning

* clean

* disabled failed instances

* use pipe v2

* Ignore arg on not supported arch.

* fix warning

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
This commit is contained in:
Adam Osewski
2023-04-24 22:43:36 +02:00
committed by GitHub
parent 8b9cbba823
commit 8bb2bb4a05
17 changed files with 1329 additions and 40 deletions

View File

@@ -5,6 +5,7 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_dependencies(example_grouped_gemm_xdl
@@ -12,7 +13,8 @@ add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16)
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)

View File

@@ -0,0 +1,97 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[])
{
ProblemSize problem_size;
ExecutionConfig config;
problem_size.group_count = 16;
problem_size.Ms = {
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++)
{
problem_size.Ns.push_back(768);
problem_size.Ks.push_back(4608);
problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
}
if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
exit(0);
}
return !run_grouped_gemm(problem_size, config);
}

View File

@@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#else
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
c_tensors_device[i]->SetZero();
#endif
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());