mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Multiple-ABD GEMM example (#2788)
* Multi ABD - initial commit * Clang-foramt fix * block gemm, unify the name of CDataType * Apply chnages to mem-pipeline * Rollback prefix for DType and Layout * Gemm Kernel Basic, rename * WMMA config * Grouped GEMM * Clang-format * Dropout, name * Review v2 * Move element_wise fn to unnary, remov old ones fn * clang-format * Fix issue review * WP operator adjust to universal gemm * v2 prepare * Remove unused comment * Remove vectorsize * Rollback * Adjust pipeline for abd * Shuffle argument * CI-fail fix quant * Fix ag_br pipeline * Failing tests * Typo * Single argument support
This commit is contained in:
@@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDElementOp,
|
||||
typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
|
||||
typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
|
||||
typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
|
||||
CK_TILE_HOST void
|
||||
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
|
||||
const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
|
||||
const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
|
||||
HostTensor<ADataType>& a_m_k,
|
||||
HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const CDElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
auto as_m_k_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});
|
||||
|
||||
auto bs_k_n_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});
|
||||
|
||||
auto ds_m_n_tuple =
|
||||
generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});
|
||||
|
||||
// Apply elementwise function to A
|
||||
auto a_elementwise_fn = [&](auto i, auto j) {
|
||||
ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());
|
||||
|
||||
// Apply elementwise function to B
|
||||
auto b_elementwise_fn = [&](auto i, auto j) {
|
||||
ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());
|
||||
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_m_k(m, k);
|
||||
BDataType v_b = b_k_n(k, n);
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
CDataType v_c = 0;
|
||||
|
||||
ck_tile::apply(
|
||||
[&](auto&&... t) {
|
||||
acc_element_op(v_c,
|
||||
ck_tile::type_convert<float>(v_acc),
|
||||
ck_tile::type_convert<float>(t(m, n))...);
|
||||
},
|
||||
ds_m_n_tuple);
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
|
||||
Reference in New Issue
Block a user