Multiple ABD

This commit is contained in:
Mateusz Ozga
2025-06-03 13:23:34 +00:00
parent ee8f3d3323
commit f4e3fb5915
10 changed files with 225 additions and 124 deletions

View File

@@ -0,0 +1 @@
add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp)

View File

@@ -0,0 +1,35 @@
#Multiple ABD GEMM
This folder contains example for Multiple ABD GEMM using ck_tile tile-programming implementation.
## build
```
#in the root of ck_tile
mkdir build && cd build
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
#The basic pipeline method on the gemm calculation
make tile_example_gemm_multi_abd_fp16 -j
```
This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16`
## example
```
args:
-m M dimensions - (Default: 3840)
-n N dimensions - (Default: 4096)
-k K dimensions - (Default: 4096)
-as_layout Tensor A layout (default:R)
-bs_layout Tensor B layout (default:C)
-ds_layout Tensor D layout (default:R)
-e_layout Tensor E layout (default:R)
-stride_as Tensor A strides - (Default: 0)
-stride_bs Tensor B strides - (Default: 0)
-stride_e Tensor C strides - (Default: 0)
-stride_ds Tensor D strides - (Default: 0)
-validate 0. No validation, 1. Validation on GPU. (Default: 1)
-warmup Number of iterations before benchmark the kernel. (Default: 10)
-repeat Number of iterations to benchmark the kernel. (Default: 100)
-kbatch kbatch for SplitK. (Default: 1)
```

View File

@@ -14,22 +14,22 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "multi_abd_gemm.hpp"
#include "gemm_multi_abd_fp16.hpp"
#include "utils.hpp"
template <typename AsDataType,
typename BsDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename EDataType,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename CLayout,
typename AElementWise = ck_tile::element_wise::PassThrough,
typename BElementWise = ck_tile::element_wise::PassThrough,
typename ELayout,
typename AsElementWise = ck_tile::element_wise::PassThrough,
typename BsElementWise = ck_tile::element_wise::PassThrough,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::stream_config& s) -> float
auto gemm_multiple_abd(const gemm_multiple_abd_kargs& args, const ck_tile::stream_config& s) -> float
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
@@ -80,7 +80,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
@@ -98,7 +98,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, CLayout>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
@@ -106,7 +106,7 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
DoubleSmemBuffer,
AsLayout,
BsLayout,
CLayout,
ELayout,
TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;
@@ -132,8 +132,8 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
BsDataType,
AccDataType,
AElementWise,
BElementWise,
AsElementWise,
BsElementWise,
GemmShape,
GemmUniversalTraits,
scheduler,
@@ -145,8 +145,8 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
ck_tile::CShuffleEpilogueProblem<AsDataType,
BsDataType,
AccDataType,
CDataType,
CLayout,
EDataType,
ELayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -292,6 +292,6 @@ auto multiple_abd_gemm(const multiple_abd_gemm_kargs& args, const ck_tile::strea
return ave_time;
}
#include "run_multi_abd_gemm_example.inc"
#include "run_gemm_multi_abd_fp16_example.inc"
int main(int argc, char* argv[]) { return !run_multiple_abd_gemm_example(argc, argv); }

View File

@@ -42,7 +42,7 @@ using B1DataType = ck_tile::half_t;
using D0DataType = ck_tile::half_t;
using D1DataType = ck_tile::half_t;
using CDataType = ck_tile::half_t;
using EDataType = ck_tile::half_t;
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
@@ -56,21 +56,34 @@ auto create_args(int argc, char* argv[])
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "4096", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Col by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_d", "0", "Tensor Ds stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("as_layout", "R", "As tensor data layout - Row by default")
.insert("bs_layout", "C", "Bs tensor data layout - Col by default")
.insert("ds_layout", "R", "Ds tensor data layout - Row by default")
.insert("e_layout", "R", "E tensor data layout - Row by default")
.insert("stride_as", "0", "Tensor A stride")
.insert("stride_bs", "0", "Tensor B stride")
.insert("stride_ds", "0", "Tensor Ds stride")
.insert("stride_e", "0", "Tensor E stride")
.insert("v", "1", "0. No validation, 1. Validation on GPU")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel");
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("kbatch", "1", "kbatch for SplitK");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
using gemm_multiple_abd_kargs = ck_tile::GemmHostArgs<AsDataType::size(), BsDataType::size()>;
using multiple_abd_gemm_kargs = ck_tile::GemmHostArgs<AsDataType::size(), BsDataType::size()>;
float multiple_abd_gemm(const multiple_abd_gemm_kargs& kargs, const ck_tile::stream_config& s);
template <typename AsDataType,
typename BsDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename CLayout,
typename AElementWise,
typename BElementWise,
typename CDEElementWise>
float gemm_multiple_abd(const gemm_multiple_abd_kargs& kargs, const ck_tile::stream_config& s);

View File

@@ -8,70 +8,70 @@ template <typename AsDataType,
typename BsDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename EDataType,
typename AsLayout,
typename BsLayout,
typename DsLayout,
typename CLayout,
typename AElementWise = ck_tile::element_wise::PassThrough,
typename BElementWise = ck_tile::element_wise::PassThrough,
typename ELayout,
typename AsElementWise = ck_tile::element_wise::PassThrough,
typename BsElementWise = ck_tile::element_wise::PassThrough,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_multi_abd_gemm(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
float invoke_gemm_multi_abd(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
const std::array<const void*, BsDataType::size()>& bs_k_n_dev_buf,
[[maybe_unused]] const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
void* c_m_n_dev_buf,
void* e_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
const std::array<ck_tile::index_t, AsDataType::size()>& StrideAs,
const std::array<ck_tile::index_t, BsDataType::size()>& StrideBs,
[[maybe_unused]] const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
ck_tile::index_t StrideC,
ck_tile::index_t StrideE,
int n_warmup,
int n_repeat)
int n_repeat,
int k_batch)
{
multiple_abd_gemm_kargs gemm_descs({as_m_k_dev_buf,
gemm_multiple_abd_kargs gemm_descs({as_m_k_dev_buf,
bs_k_n_dev_buf,
//ds_m_n_dev_buf,
c_m_n_dev_buf,
/*kbatch */ 1,
e_m_n_dev_buf,
k_batch,
M,
N,
K,
StrideAs,
StrideBs,
//StrideDs,
StrideC});
StrideE});
float ave_time = multiple_abd_gemm<AsDataType,
float ave_time = gemm_multiple_abd<AsDataType,
BsDataType,
DsDataType,
AccDataType,
CDataType,
EDataType,
AsLayout,
BsLayout,
DsLayout,
CLayout,
AElementWise,
BElementWise,
ELayout,
AsElementWise,
BsElementWise,
CDEElementWise>(
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Multiple-D Gemm"};
//static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
std::string op_name{"Gemm Multiple-ABD"};
std::size_t flop = 0, num_btype = 0;
flop += std::size_t(2) * M * N * K;
num_btype += sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(CDataType) * M * N;
num_btype += sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Run Multiple-D Gemm kernel with:\n";
std::cout << "Run Gemm Multiple-ABD kernel with:\n";
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
//std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC
//std::cout << "StrideA = " << StrideA << " StrideB = " << StrideB << " StrideE = " << StrideE
// << "\n";
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< "\n";
@@ -85,8 +85,8 @@ template <typename A0Layout,
typename B1Layout,
typename D0Layout,
typename D1Layout,
typename CLayout>
int run_multiple_abd_gemm_example_with_layouts(int argc,
typename ELayout>
int run_gemm_multi_abd_example_with_layouts(int argc,
char* argv[],
const A0Layout a0_layout = A0Layout{},
const A1Layout a1_layout = A1Layout{},
@@ -94,7 +94,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
const B1Layout b1_layout = B1Layout{},
const D0Layout d0_layout = D0Layout{},
const D1Layout d1_layout = D1Layout{},
const CLayout c_layout = CLayout{})
const ELayout e_layout = ELayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
@@ -112,10 +112,10 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t StrideA = arg_parser.get_int("stride_a");
ck_tile::index_t StrideB = arg_parser.get_int("stride_b");
ck_tile::index_t StrideD = arg_parser.get_int("stride_d");
ck_tile::index_t StrideC = arg_parser.get_int("stride_c");
ck_tile::index_t StrideA = arg_parser.get_int("stride_as");
ck_tile::index_t StrideB = arg_parser.get_int("stride_bs");
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
ck_tile::index_t StrideA0 = StrideA;
ck_tile::index_t StrideA1 = StrideA;
@@ -128,6 +128,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
const int n_warmup = arg_parser.get_int("warmup");
const int n_repeat = arg_parser.get_int("repeat");
const int k_batch = arg_parser.get_int("kbatch");
StrideA0 = f_get_default_stride(M, N, StrideA0, a0_layout);
StrideA1 = f_get_default_stride(M, N, StrideA1, a1_layout);
@@ -138,7 +139,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
StrideD0 = f_get_default_stride(M, N, StrideD0, d0_layout);
StrideD1 = f_get_default_stride(M, N, StrideD1, d1_layout);
StrideC = f_get_default_stride(M, N, StrideC, c_layout);
StrideE = f_get_default_stride(M, N, StrideE, e_layout);
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
f_host_tensor_descriptor(M, K, StrideA0, a0_layout));
@@ -155,8 +156,8 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
f_host_tensor_descriptor(M, N, StrideD1, d1_layout));
ck_tile::HostTensor<CDataType> e_m_n_device_result(
f_host_tensor_descriptor(M, N, StrideC, c_layout));
ck_tile::HostTensor<EDataType> e_m_n_device_result(
f_host_tensor_descriptor(M, N, StrideE, e_layout));
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a1_m_k_tesnor);
@@ -176,7 +177,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
@@ -187,7 +188,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
c_m_n_dev_buf.SetZero();
e_m_n_dev_buf.SetZero();
e_m_n_device_result.SetZero();
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
@@ -203,36 +204,37 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
invoke_multi_abd_gemm<AsDataType,
invoke_gemm_multi_abd<AsDataType,
BsDataType,
DsDataType,
AccDataType,
CDataType,
EDataType,
AsLayout,
BsLayout,
DsLayout,
CLayout,
ELayout,
AElementWiseFn,
BElementWiseFn,
CDEElementWiseFn>(as_ptr_buf,
bs_ptr_buf,
ds_ptr_buf,
c_m_n_dev_buf.GetDeviceBuffer(),
e_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
strideAs,
strideBs,
strideDs,
StrideC,
StrideE,
n_warmup,
n_repeat);
n_repeat,
k_batch);
c_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
ck_tile::HostTensor<A0DataType> a_m_k_host_ref(f_host_tensor_descriptor(M, K, StrideA0, a0_layout));
ck_tile::HostTensor<B0DataType> b_k_n_host_ref(f_host_tensor_descriptor(K, N, StrideB0, b0_layout));
ck_tile::HostTensor<CDataType> e_m_n_host_ref(f_host_tensor_descriptor(M, N, StrideC, c_layout));
ck_tile::HostTensor<EDataType> e_m_n_host_ref(f_host_tensor_descriptor(M, N, StrideE, e_layout));
a_m_k_host_ref.SetZero();
b_k_n_host_ref.SetZero();
e_m_n_host_ref.SetZero();
@@ -242,7 +244,7 @@ int run_multiple_abd_gemm_example_with_layouts(int argc,
B0DataType,
D0DataType,
AccDataType,
CDataType,
EDataType,
AElementWiseFn,
BElementWiseFn,
CDEElementWiseFn>(
@@ -279,18 +281,15 @@ int run_multiple_abd_gemm_example(int argc, char* argv[])
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string as_layout = arg_parser.get_str("as_layout");
const std::string bs_layout = arg_parser.get_str("bs_layout");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
if(as_layout == "R" && bs_layout == "C")
{
return run_multiple_abd_gemm_example_with_layouts(
return run_gemm_multi_abd_example_with_layouts(
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{});
}
else

View File

@@ -45,17 +45,17 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
max_accumulated_value, kbatch);
// Use higher threshold

View File

@@ -1 +0,0 @@
add_executable(tile_example_multi_abd_gemm EXCLUDE_FROM_ALL multi_abd_gemm.cpp)

View File

@@ -1,33 +0,0 @@
#Multiple D GEMM
This folder contains example for Multiple D GEMM using ck_tile tile-programming implementation.
## build
```
#in the root of ck_tile
mkdir build && cd build
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
#The basic pipeline method on the gemm calculation
make tile_example_multi_d_gemm -j
```
This will result in an executable `build/bin/tile_example_multi_d_gemm`
## example
```
args:
-m M dimensions - (Default: 3840)
-n N dimensions - (Default: 4096)
-k K dimensions - (Default: 4096)
-a_layout Tensor A layout (default:R)
-b_layout Tensor B layout (default:C)
-c_layout Tensor C layout (default:R)
-stride_a Tensor A strides - (Default: 0)
-stride_b Tensor B strides - (Default: 0)
-stride_c Tensor C strides - (Default: 0)
-stride_d Tensor C strides - (Default: 0)
-validate 0. No validation, 1. Validation on GPU. (Default: 1)
-warmup Number of iterations before benchmark the kernel. (Default: 10)
-repeat Number of iterations to benchmark the kernel. (Default: 100)
```

View File

@@ -18,6 +18,6 @@ add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(18_flatmm)
add_subdirectory(20_multi_abd_gemm)
add_subdirectory(20_gemm_multi_abd)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_copy)

View File

@@ -13,6 +13,93 @@
namespace ck_tile {
/// @brief The GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
/// object. It contain all necessary information required to build proper kernel argument
/// and launch kernel on GPU.
/// This structure defines the GEMM problem configuration by stating all required information
/// like M,N,K sizes and respective strides.
/// NumDTensor describes the number of D tensors.
template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
struct GemmHostArgs
{
CK_TILE_HOST GemmHostArgs() = default;
CK_TILE_HOST GemmHostArgs(const std::array<const void*, NumATensor>& as_ptr_,
const std::array<const void*, NumBTensor>& bs_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
const std::array<index_t, NumATensor>& stride_As_,
const std::array<index_t, NumBTensor>& stride_Bs_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const std::array<const void*, NumATensor> as_ptr;
const std::array<const void*, NumBTensor> bs_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
void* e_ptr;
index_t M;
index_t N;
index_t K;
const std::array<index_t, NumATensor> stride_As;
const std::array<index_t, NumBTensor> stride_Bs;
const std::array<index_t, NumDTensor> stride_Ds;
index_t stride_E;
index_t k_batch;
};
/// @brief The GEMM kernel device arguments.
template <typename AType = ck_tile::tuple<>, typename BType = ck_tile::tuple<>, typename DType = ck_tile::tuple<>>
struct GemmKernelArgs
{
/// @brief The A input tensor's pointer to device memory.
const AType* a_ptr;
/// @brief The B input tensor's pointer to device memory.
const BType* b_ptr;
/// @brief The B input tensor's pointer to device memory.
const DType ds_ptr;
/// @brief The E output tensor's pointer to device memory.
void* e_ptr;
/// @brief GEMM's M dimension size.
index_t M;
/// @brief GEMM's N dimension size.
index_t N;
/// @brief GEMM's K dimension size.
index_t K;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of As tensor.
std::array<index_t, AType::size()> stride_As;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of Bs tensor.
std::array<index_t, BType::size()> stride_Bs;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of Ds tensor.
std::array<index_t, DType::size()> stride_Ds;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of E tensor.
index_t stride_E;
index_t k_batch;
};
template <index_t NumATensor = 1, index_t NumBTensor = 1>
struct GemmHostArgs
{
@@ -130,7 +217,7 @@ struct GemmKernel
using AsDataType = remove_cvref_t<typename GemmPipeline::AsDataType>;
using BsDataType = remove_cvref_t<typename GemmPipeline::BsDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
@@ -267,7 +354,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs)
{
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
is_any_of<EDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
{
@@ -417,7 +504,7 @@ struct GemmKernel
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const AsGridPointer as_ptr,
const BsGridPointer bs_ptr,
CDataType* c_ptr,
EDataType* c_ptr,
const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
@@ -668,7 +755,7 @@ struct GemmKernel
*/
CK_TILE_DEVICE static void RunGemm(const AsGridPointer as_ptr,
const BsGridPointer bs_ptr,
CDataType* c_ptr,
EDataType* c_ptr,
void* smem_ptr_0,
const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
@@ -720,7 +807,7 @@ struct GemmKernel
*/
CK_TILE_DEVICE static void RunGemm2LDS(const AsGridPointer as_ptr,
const BsGridPointer bs_ptr,
CDataType* c_ptr,
EDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GemmKernelArgs<AsGridPointer, BsGridPointer>& kargs,
@@ -774,7 +861,7 @@ struct GemmKernel
bs_ptr(i) = static_cast<const BDataType*>(kargs.bs_ptr[i]) + splitk_batch_offset.b_k_split_offset[i];
});
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
EDataType* c_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
@@ -784,7 +871,7 @@ struct GemmKernel
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(as_ptr,
bs_ptr,
@@ -801,7 +888,7 @@ struct GemmKernel
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
RunGemm(as_ptr, bs_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}