mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
Multiple ABD
This commit is contained in:
1
example/ck_tile/20_gemm_multi_abd/CMakeLists.txt
Normal file
1
example/ck_tile/20_gemm_multi_abd/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp)
|
||||
35
example/ck_tile/20_gemm_multi_abd/README.md
Normal file
35
example/ck_tile/20_gemm_multi_abd/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Multiple ABD GEMM
|
||||
|
||||
This folder contains an example of Multiple ABD GEMM using the ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942)
# or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_multi_abd_fp16 -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m M dimensions - (Default: 3840)
|
||||
-n N dimensions - (Default: 4096)
|
||||
-k K dimensions - (Default: 4096)
|
||||
-as_layout Tensor A layout (default:R)
|
||||
-bs_layout Tensor B layout (default:C)
|
||||
-ds_layout Tensor D layout (default:R)
|
||||
-e_layout Tensor E layout (default:R)
|
||||
-stride_as Tensor A strides - (Default: 0)
|
||||
-stride_bs Tensor B strides - (Default: 0)
|
||||
-stride_e Tensor E strides - (Default: 0)
|
||||
-stride_ds Tensor D strides - (Default: 0)
|
||||
-v 0. No validation, 1. Validation on GPU. (Default: 1)
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 50)
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100)
|
||||
-kbatch kbatch for SplitK. (Default: 1)
|
||||
```
|
||||
@@ -14,22 +14,22 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "multi_abd_gemm.hpp"
|
||||
#include "gemm_multi_abd_fp16.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename ELayout,
|
||||
typename AsElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BsElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::stream_config& s) -> float
|
||||
auto gemm_multiple_abd(const gemm_multiple_abd_kargs& args, const ck_tile::stream_config& s) -> float
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
@@ -80,7 +80,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
|
||||
constexpr bool DoubleSmemBuffer = true;
|
||||
#endif
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
@@ -98,7 +98,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, CLayout>;
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
@@ -106,7 +106,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
|
||||
DoubleSmemBuffer,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
CLayout,
|
||||
ELayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;
|
||||
@@ -132,8 +132,8 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
AElementWise,
|
||||
BElementWise,
|
||||
AsElementWise,
|
||||
BsElementWise,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
@@ -145,8 +145,8 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
EDataType,
|
||||
ELayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -292,6 +292,6 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_multi_abd_gemm_example.inc"
|
||||
#include "run_gemm_multi_abd_fp16_example.inc"
|
||||
|
||||
/// Program entry point: forwards the command line to the example runner and
/// returns the logical negation of the runner's result as the exit code.
int main(int argc, char* argv[])
{
    const auto example_result = run_multiple_abd_gemm_example(argc, argv);
    return !example_result;
}
|
||||
@@ -42,7 +42,7 @@ using B1DataType = ck_tile::half_t;
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
|
||||
using CDataType = ck_tile::half_t;
|
||||
using EDataType = ck_tile::half_t;
|
||||
|
||||
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
|
||||
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
|
||||
@@ -56,21 +56,34 @@ auto create_args(int argc, char* argv[])
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_d", "0", "Tensor Ds stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("as_layout", "R", "As tensor data layout - Row by default")
|
||||
.insert("bs_layout", "C", "Bs tensor data layout - Col by default")
|
||||
.insert("ds_layout", "R", "Ds tensor data layout - Row by default")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default")
|
||||
.insert("stride_as", "0", "Tensor A stride")
|
||||
.insert("stride_bs", "0", "Tensor B stride")
|
||||
.insert("stride_ds", "0", "Tensor Ds stride")
|
||||
.insert("stride_e", "0", "Tensor E stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel");
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
using gemm_multiple_abd_kargs = ck_tile::GemmHostArgs<AsDataType::size(), BsDataType::size()>;
|
||||
|
||||
using multiple_abd_gemm_kargs = ck_tile::GemmHostArgs<AsDataType::size(), BsDataType::size()>;
|
||||
|
||||
float multiple_abd_gemm(const multiple_abd_gemm_kargs& kargs, const ck_tile::stream_config& s);
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename AElementWise,
|
||||
typename BElementWise,
|
||||
typename CDEElementWise>
|
||||
float gemm_multiple_abd(const gemm_multiple_abd_kargs& kargs, const ck_tile::stream_config& s);
|
||||
@@ -8,70 +8,70 @@ template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename ELayout,
|
||||
typename AsElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BsElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_multi_abd_gemm(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
|
||||
float invoke_gemm_multi_abd(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
|
||||
const std::array<const void*, BsDataType::size()>& bs_k_n_dev_buf,
|
||||
[[maybe_unused]] const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
void* c_m_n_dev_buf,
|
||||
void* e_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
const std::array<ck_tile::index_t, AsDataType::size()>& StrideAs,
|
||||
const std::array<ck_tile::index_t, BsDataType::size()>& StrideBs,
|
||||
[[maybe_unused]] const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
ck_tile::index_t StrideC,
|
||||
ck_tile::index_t StrideE,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
int n_repeat,
|
||||
int k_batch)
|
||||
{
|
||||
multiple_abd_gemm_kargs gemm_descs({as_m_k_dev_buf,
|
||||
gemm_multiple_abd_kargs gemm_descs({as_m_k_dev_buf,
|
||||
bs_k_n_dev_buf,
|
||||
//ds_m_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
/*kbatch */ 1,
|
||||
e_m_n_dev_buf,
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
//StrideDs,
|
||||
StrideC});
|
||||
StrideE});
|
||||
|
||||
float ave_time = multiple_abd_gemm<AsDataType,
|
||||
float ave_time = gemm_multiple_abd<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
AElementWise,
|
||||
BElementWise,
|
||||
ELayout,
|
||||
AsElementWise,
|
||||
BsElementWise,
|
||||
CDEElementWise>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Multiple-D Gemm"};
|
||||
//static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
std::string op_name{"Gemm Multiple-ABD"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
num_btype += sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(CDataType) * M * N;
|
||||
num_btype += sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Multiple-D Gemm kernel with:\n";
|
||||
std::cout << "Run Gemm Multiple-ABD kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
//std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
|
||||
//std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE
|
||||
// << "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
@@ -85,8 +85,8 @@ template <typename A0Layout,
|
||||
typename B1Layout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename CLayout>
|
||||
int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
typename ELayout>
|
||||
int run_gemm_multi_abd_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const A0Layout a0_layout = A0Layout{},
|
||||
const A1Layout a1_layout = A1Layout{},
|
||||
@@ -94,7 +94,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
const B1Layout b1_layout = B1Layout{},
|
||||
const D0Layout d0_layout = D0Layout{},
|
||||
const D1Layout d1_layout = D1Layout{},
|
||||
const CLayout c_layout = CLayout{})
|
||||
const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
@@ -112,10 +112,10 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t StrideA = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t StrideB = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t StrideD = arg_parser.get_int("stride_d");
|
||||
ck_tile::index_t StrideC = arg_parser.get_int("stride_c");
|
||||
ck_tile::index_t StrideA = arg_parser.get_int("stride_as");
|
||||
ck_tile::index_t StrideB = arg_parser.get_int("stride_bs");
|
||||
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
|
||||
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
|
||||
|
||||
ck_tile::index_t StrideA0 = StrideA;
|
||||
ck_tile::index_t StrideA1 = StrideA;
|
||||
@@ -128,6 +128,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
|
||||
const int n_warmup = arg_parser.get_int("warmup");
|
||||
const int n_repeat = arg_parser.get_int("repeat");
|
||||
const int k_batch = arg_parser.get_int("kbatch");
|
||||
|
||||
StrideA0 = f_get_default_stride(M, N, StrideA0, a0_layout);
|
||||
StrideA1 = f_get_default_stride(M, N, StrideA1, a1_layout);
|
||||
@@ -138,7 +139,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
StrideD0 = f_get_default_stride(M, N, StrideD0, d0_layout);
|
||||
StrideD1 = f_get_default_stride(M, N, StrideD1, d1_layout);
|
||||
|
||||
StrideC = f_get_default_stride(M, N, StrideC, c_layout);
|
||||
StrideE = f_get_default_stride(M, N, StrideE, e_layout);
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
|
||||
f_host_tensor_descriptor(M, K, StrideA0, a0_layout));
|
||||
@@ -155,8 +156,8 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
|
||||
f_host_tensor_descriptor(M, N, StrideD1, d1_layout));
|
||||
|
||||
ck_tile::HostTensor<CDataType> e_m_n_device_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, c_layout));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
f_host_tensor_descriptor(M, N, StrideE, e_layout));
|
||||
|
||||
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
|
||||
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a1_m_k_tesnor);
|
||||
@@ -176,7 +177,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
|
||||
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
|
||||
@@ -187,7 +188,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
|
||||
|
||||
c_m_n_dev_buf.SetZero();
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
|
||||
@@ -203,36 +204,37 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
|
||||
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
|
||||
|
||||
invoke_multi_abd_gemm<AsDataType,
|
||||
invoke_gemm_multi_abd<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ELayout,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>(as_ptr_buf,
|
||||
bs_ptr_buf,
|
||||
ds_ptr_buf,
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
strideAs,
|
||||
strideBs,
|
||||
strideDs,
|
||||
StrideC,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
n_repeat,
|
||||
k_batch);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a_m_k_host_ref(f_host_tensor_descriptor(M, K, StrideA0, a0_layout));
|
||||
ck_tile::HostTensor<B0DataType> b_k_n_host_ref(f_host_tensor_descriptor(K, N, StrideB0, b0_layout));
|
||||
ck_tile::HostTensor<CDataType> e_m_n_host_ref(f_host_tensor_descriptor(M, N, StrideC, c_layout));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_ref(f_host_tensor_descriptor(M, N, StrideE, e_layout));
|
||||
a_m_k_host_ref.SetZero();
|
||||
b_k_n_host_ref.SetZero();
|
||||
e_m_n_host_ref.SetZero();
|
||||
@@ -242,7 +244,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
|
||||
B0DataType,
|
||||
D0DataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
EDataType,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>(
|
||||
@@ -279,18 +281,15 @@ int run_multiple_abd_gemm_example(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string as_layout = arg_parser.get_str("as_layout");
|
||||
const std::string bs_layout = arg_parser.get_str("bs_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
if(as_layout == "R" && bs_layout == "C")
|
||||
{
|
||||
return run_multiple_abd_gemm_example_with_layouts(
|
||||
return run_gemm_multi_abd_example_with_layouts(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -45,17 +45,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
|
||||
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
@@ -1 +0,0 @@
|
||||
add_executable(tile_example_multi_abd_gemm EXCLUDE_FROM_ALL multi_abd_gemm.cpp)
|
||||
@@ -1,33 +0,0 @@
|
||||
#Multiple D GEMM
|
||||
|
||||
This folder contains example for Multiple D GEMM using ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
#in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
|
||||
leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
#The basic pipeline method on the gemm calculation
|
||||
make tile_example_multi_d_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_multi_d_gemm`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m M dimensions - (Default: 3840)
|
||||
-n N dimensions - (Default: 4096)
|
||||
-k K dimensions - (Default: 4096)
|
||||
-a_layout Tensor A layout (default:R)
|
||||
-b_layout Tensor B layout (default:C)
|
||||
-c_layout Tensor C layout (default:R)
|
||||
-stride_a Tensor A strides - (Default: 0)
|
||||
-stride_b Tensor B strides - (Default: 0)
|
||||
-stride_c Tensor C strides - (Default: 0)
|
||||
-stride_d Tensor C strides - (Default: 0)
|
||||
-validate 0. No validation, 1. Validation on GPU. (Default: 1)
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10)
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100)
|
||||
```
|
||||
@@ -18,6 +18,6 @@ add_subdirectory(15_fused_moe)
|
||||
add_subdirectory(16_batched_gemm)
|
||||
add_subdirectory(17_grouped_gemm)
|
||||
add_subdirectory(18_flatmm)
|
||||
add_subdirectory(20_multi_abd_gemm)
|
||||
add_subdirectory(20_gemm_multi_abd)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(36_copy)
|
||||
|
||||
@@ -13,6 +13,93 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
|
||||
/// object. It contain all necessary information required to build proper kernel argument
|
||||
/// and launch kernel on GPU.
|
||||
/// This structure defines the GEMM problem configuration by stating all required information
|
||||
/// like M,N,K sizes and respective strides.
|
||||
/// NumDTensor describes the number of D tensors.
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
|
||||
struct GemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST GemmHostArgs() = default;
|
||||
CK_TILE_HOST GemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
|
||||
const std::array<const void*, NumBTensor>& bs_ptr_,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr_,
|
||||
void* e_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
const std::array<index_t, NumATensor>& stride_As_,
|
||||
const std::array<index_t, NumBTensor>& stride_Bs_,
|
||||
const std::array<index_t, NumDTensor>& stride_Ds_,
|
||||
index_t stride_E_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_Ds(stride_Ds_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const std::array<const void*, NumATensor> as_ptr;
|
||||
const std::array<const void*, NumBTensor> bs_ptr;
|
||||
const std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* e_ptr;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
const std::array<index_t, NumATensor> stride_As;
|
||||
const std::array<index_t, NumBTensor> stride_Bs;
|
||||
const std::array<index_t, NumDTensor> stride_Ds;
|
||||
index_t stride_E;
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
/// @brief The GEMM kernel device arguments.
///
/// AType, BType and DType must each provide a static `size()`; it is used below as
/// the extent of the corresponding stride array.
|
||||
template <typename AType = ck_tile::tuple<>, typename BType = ck_tile::tuple<>, typename DType = ck_tile::tuple<>>
|
||||
struct GemmKernelArgs
|
||||
{
|
||||
/// @brief The A input tensor's pointer to device memory.
|
||||
const AType* a_ptr;
|
||||
/// @brief The B input tensor's pointer to device memory.
|
||||
const BType* b_ptr;
|
||||
/// @brief The Ds input tensors' pointers, stored by value as a DType object
|
||||
/// (presumably pointers to device memory - TODO confirm).
const DType ds_ptr;
|
||||
/// @brief The E output tensor's pointer to device memory.
|
||||
void* e_ptr;
|
||||
/// @brief GEMM's M dimension size.
|
||||
index_t M;
|
||||
/// @brief GEMM's N dimension size.
|
||||
index_t N;
|
||||
/// @brief GEMM's K dimension size.
|
||||
index_t K;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of As tensor.
|
||||
std::array<index_t, AType::size()> stride_As;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of Bs tensor.
|
||||
std::array<index_t, BType::size()> stride_Bs;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of Ds tensor.
|
||||
std::array<index_t, DType::size()> stride_Ds;
|
||||
/// @brief The distance between consecutive elements of non-contiguous dimension
|
||||
/// (in memory) of E tensor.
|
||||
index_t stride_E;
|
||||
/// @brief The kbatch value for SplitK.
index_t k_batch;
|
||||
};
|
||||
|
||||
template <index_t NumATensor = 1, index_t NumBTensor = 1>
|
||||
struct GemmHostArgs
|
||||
{
|
||||
@@ -130,7 +217,7 @@ struct GemmKernel
|
||||
using AsDataType = remove_cvref_t<typename GemmPipeline::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename GemmPipeline::BsDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -267,7 +354,7 @@ struct GemmKernel
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs)
|
||||
{
|
||||
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value)
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
@@ -417,7 +504,7 @@ struct GemmKernel
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const AsGridPointer as_ptr,
|
||||
const BsGridPointer bs_ptr,
|
||||
CDataType* c_ptr,
|
||||
EDataType* c_ptr,
|
||||
const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
@@ -668,7 +755,7 @@ struct GemmKernel
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm(const AsGridPointer as_ptr,
|
||||
const BsGridPointer bs_ptr,
|
||||
CDataType* c_ptr,
|
||||
EDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
@@ -720,7 +807,7 @@ struct GemmKernel
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const AsGridPointer as_ptr,
|
||||
const BsGridPointer bs_ptr,
|
||||
CDataType* c_ptr,
|
||||
EDataType* c_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs,
|
||||
@@ -774,7 +861,7 @@ struct GemmKernel
|
||||
bs_ptr(i) = static_cast<const BDataType*>(kargs.bs_ptr[i]) + splitk_batch_offset.b_k_split_offset[i];
|
||||
});
|
||||
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
EDataType* c_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
@@ -784,7 +871,7 @@ struct GemmKernel
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(as_ptr,
|
||||
bs_ptr,
|
||||
@@ -801,7 +888,7 @@ struct GemmKernel
|
||||
{
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(as_ptr, bs_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user