[CK_TILE] Multiple-ABD GEMM example (#2788)

* Multi ABD - initial commit

* Clang-foramt fix

* block gemm, unify the name of CDataType

* Apply chnages to mem-pipeline

* Rollback prefix for DType and Layout

* Gemm Kernel Basic, rename

* WMMA config

* Grouped GEMM

* Clang-format

* Dropout, name

* Review v2

* Move element_wise fn to unnary, remov old ones fn

* clang-format

* Fix issue review

* WP operator adjust to universal gemm

* v2 prepare

* Remove unused comment

* Remove vectorsize

* Rollback

* Adjust pipeline for abd

* Shuffle argument

* CI-fail fix quant

* Fix ag_br pipeline

* Failing tests

* Typo

* Single argument support
This commit is contained in:
Mateusz Ozga
2025-09-19 01:14:11 +02:00
committed by GitHub
parent 14bbc545ea
commit 30ab1d6a71
41 changed files with 3603 additions and 552 deletions

View File

@@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename AsDataType,
typename BsDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename AElementOp,
typename BElementOp,
typename CDElementOp,
typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
HostTensor<ADataType>& a_m_k,
HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const CDElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto as_m_k_tuple =
generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
auto bs_k_n_tuple =
generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
auto ds_m_n_tuple =
generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});
// Apply elementwise function to A
auto a_elementwise_fn = [&](auto i, auto j) {
ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
};
make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());
// Apply elementwise function to B
auto b_elementwise_fn = [&](auto i, auto j) {
ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
};
make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());
auto f_mk_kn_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
for(std::size_t k = 0; k < K; ++k)
{
ADataType v_a = a_m_k(m, k);
BDataType v_b = b_k_n(k, n);
v_acc +=
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
}
CDataType v_c = 0;
ck_tile::apply(
[&](auto&&... t) {
acc_element_op(v_c,
ck_tile::type_convert<float>(v_acc),
ck_tile::type_convert<float>(t(m, n))...);
},
ds_m_n_tuple);
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
};
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename BDataType,
typename DsDataType,